diff --git a/.gitignore b/.gitignore index 83ef8ea38cd3..e24999ef0d5c 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,9 @@ conda/pkg .envrc *.nix +# Docker files +.sudo_as_admin_successful + # Downloaded models/datasets .tvm_test_data .dgl diff --git a/CMakeLists.txt b/CMakeLists.txt index b499edd4560f..c56a929e276d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,7 +172,7 @@ else(MSVC) # ld option to warn if symbols are undefined (e.g. libtvm_runtime.so # using symbols only present in libtvm.so). Not needed for MSVC, # since this is already the default there. - if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin" OR ${CMAKE_SYSTEM_NAME} MATCHES "iOS") set(TVM_NO_UNDEFINED_SYMBOLS "-Wl,-undefined,error") else() set(TVM_NO_UNDEFINED_SYMBOLS "-Wl,--no-undefined") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 3a231f083082..731b936ee6a2 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,33 +42,33 @@ We do encourage everyone to work anything they are interested in. - [Aditya Atluri](https://github.com/adityaatluri): @adityaatluri - rocm - [Matthew Barrett](https://github.com/mbaret): @mbaret - byoc, arm - [Matthew Brookhart](https://github.com/mbrookhart): @mbrookhart - relay, frontends -- [Tianqi Chen](https://github.com/tqchen) (PMC): @tqchen - topi, compiler, relay, docs - [Liangfu Chen](https://github.com/liangfu): @liangfu - vta, chisel, intel FPGA, c runtime +- [Tianqi Chen](https://github.com/tqchen) (PMC): @tqchen - topi, compiler, relay, docs - [Wei Chen](https://github.com/wweic): @wweic - runtime, relay, vm - [Zhi Chen](https://github.com/zhiics) (PMC): @zhiics - relay, quantization, pass manager -- [Chenfan](https://github.com/jcf94): @jcf94 - auto_scheduler - [Josh Fromm](https://github.com/jwfromm): @jwfromm - frontends, quantization, topi - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends - [Nick Hynes](https://github.com/nhynes): @nhynes: - sgx, rust - [Animesh Jain](https://github.com/anijain2305): @anijain2305 - quantization, relay +- [Chenfan Jia](https://github.com/jcf94): @jcf94 - auto_scheduler - [Ziheng Jiang](https://github.com/ZihengJiang) (PMC): @ZihengJiang - relay, compiler - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame - relay - [Wuwei Lin](https://github.com/vinx13): @vinx13 - relay, topi - [Yizhi Liu](https://github.com/yzhliu) (PMC): @yzhliu - jvm, topi, relay -- [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - relay - [Hao Lu](https://github.com/hlu1): @hlu1 - nnpack, frontends +- [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - relay - [Masahiro Masuda](https://github.com/masahi) (PMC): @masahi - topi, relay - [Thierry Moreau](https://github.com/tmoreau89) (PMC): @tmoreau89 - vta - [Kazutaka Morita](https://github.com/kazum): @kazum - frontends, opencl - [Trevor Morris](https://github.com/trevor-m): @trevor-m - byoc, compiler - [Leandro Nunes](https://github.com/leandron): @leandron - tvmc - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - hexagon, llvm -- [Andrew Reusch](https://github.com/areusch): @areusch - runtime, µTVM +- [Andrew Reusch](https://github.com/areusch): @areusch - runtime, microTVM - [Jared Roesch](https://github.com/jroesch) (PMC): @jroesch - relay - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - frontends -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Junru Shao](https://github.com/junrushao1994) @junrushao1994 - relay, 
compiler - [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi +- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - vta, chisel @@ -86,28 +86,25 @@ We do encourage everyone to work anything they are interested in. - [Matthew Barrett](https://github.com/mbaret): @mbaret - [Arnaud Bergeron](https://github.com/abergeron): @abergeron - [Matthew Brookhart](https://github.com/mbrookhart): @mbrookhart -- [Tianqi Chen](https://github.com/tqchen): @tqchen - [Liangfu Chen](https://github.com/liangfu): @liangfu +- [Tianqi Chen](https://github.com/tqchen): @tqchen - [Zhi Chen](https://github.com/zhiics): @zhiics -- [Chenfan](https://github.com/jcf94): @jcf94 - [Neo Chien](https://github.com/cchung100m): @cchung100m - [Meghan Cowan](https://github.com/cowanmeg): @cowanmeg - [Balint Cristian](https://github.com/cbalint13): @cbalint13 +- [Egor Churaev](https://github.com/echuraev): @echuraev - metal +- [Xiaoqiang Dan](https://github.com/xqdan): @xqdan - [Haozheng Fan](https://github.com/hzfan): @hzfan -- [Josh Fromm](https://github.com/jwfromm): @jwfromm - [Siyuan Feng](https://github.com/Hzfengsy): @Hzfengsy +- [Josh Fromm](https://github.com/jwfromm): @jwfromm - [Sergei Grechanik](https://github.com/sgrechanik-h): @sgrechanik-h -- [Hao Lu](https://github.com/hlu1): @hlu1 - [Bohan Hou](https://github.com/spectrometerHBH): @spectrometerHBH -- [Nick Hynes](https://github.com/nhynes): @nhynes - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - [Luke Hutton](https://github.com/lhutton1): @lhutton1 +- [Nick Hynes](https://github.com/nhynes): @nhynes - [Animesh Jain](https://github.com/anijain2305): @anijain2305 +- [Chenfan Jia](https://github.com/jcf94): @jcf94 - [Hua Jiang](https://github.com/huajsj): @huajsj -- [Leandro Nunes](https://github.com/leandron): @leandron -- [Yizhi Liu](https://github.com/yzhliu) : @yzhliu -- [Zhixun Tan](https://github.com/phisiart): @phisiart -- [Xiaoqiang Dan](https://github.com/xqdan): @xqdan - [Ziheng Jiang](https://github.com/ZihengJiang): @ZihengJiang - [Manupa Karunaratne](https://github.com/manupa-arm): @manupa-arm - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame @@ -116,6 +113,8 @@ We do encourage everyone to work anything they are interested in. - [Andrew Liu](https://github.com/hypercubestart): @hypercubestart - [Henry Liu](https://github.com/optima2005): @optima2005 - [Xin Liu](https://github.com/Meteorix): @Meteorix +- [Yizhi Liu](https://github.com/yzhliu) : @yzhliu +- [Hao Lu](https://github.com/hlu1): @hlu1 - [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - [Masahiro Masuda](https://github.com/masahi): @masahi - [Sergey Mironov](https://github.com/grwlf): @grwlf @@ -123,26 +122,28 @@ We do encourage everyone to work anything they are interested in. 
- [Kazutaka Morita](https://github.com/kazum): @kazum - [Trevor Morris](https://github.com/trevor-m): @trevor-m - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t +- [Leandro Nunes](https://github.com/leandron): @leandron - [Wei Pan](https://github.com/wpan11nv): @wpan11nv - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 - [Josh Pollock](https://github.com/joshpoll): @joshpoll +- [Andrew Reusch](https://github.com/areusch): @areusch - [Jared Roesch](https://github.com/jroesch): @jroesch - [Giuseppe Rossini](https://github.com/giuseros): @giuseros -- [Andrew Reusch](https://github.com/areusch): @areusch -- [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - [Junru Shao](https://github.com/junrushao1994): @junrushao1994 - [Haichen Shen](https://github.com/icemelon9): @icemelon9 - [Xingjian Shi](https://github.com/sxjscience): @sxjscience +- [Siva](https://github.com/srkreddy1238): @srkreddy1238 +- [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov - [Jon Soifer](https://github.com/soiferj): @soiferj +- [Zhixun Tan](https://github.com/phisiart): @phisiart - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - [Thomas Viehmann](https://github.com/t-vi): @t-vi -- [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Yao Wang](https://github.com/kevinthesun): @kevinthesun - [Leyuan Wang](https://github.com/Laurawly): @Laurawly +- [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Logan Weber](https://github.com/weberlo): @weberlo - [Jian Weng](https://github.com/were): @were - [Yong Wu](https://github.com/yongwww): @yongwww @@ -155,9 +156,3 @@ We do encourage everyone to work anything they are interested in. ## List of Contributors - [Full List of Contributors](https://github.com/apache/tvm/graphs/contributors) - - To contributors: please add your name to the list. -- [Qiao Zhang](https://github.com/zhangqiaorjc) -- [Haolong Zhang](https://github.com/haolongzhangm) -- [Cody Hao Yu](https://github.com/comaniac) -- [Chris Nuernberger](https://github.com/cnuernber) -- [Shoubhik Bhattacharya](https://github.com/shoubhik) diff --git a/NEWS.md b/NEWS.md index a5da068c895c..c1f0276ee713 100644 --- a/NEWS.md +++ b/NEWS.md @@ -36,7 +36,7 @@ v0.7 brings many major features. The community works together to refactor the in * Intial Hexagon support * Bring your own codegen (BYOC) support -The community also continues to bring high quality improvements to the existing modules including, but not limited to: better frontend coverage, performance, quantization, uTVM and dynamic shape support. +The community also continues to bring high quality improvements to the existing modules including, but not limited to: better frontend coverage, performance, quantization, microTVM and dynamic shape support. 
## New Features ### Automatic Scheduling (Experimental) diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index c0bd7070412a..1331e1a65ca8 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -62,6 +62,7 @@ #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_module.cc" +#include "../src/runtime/opencl/texture_pool.cc" #include "../src/runtime/source_utils.cc" #endif diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 5f703e1dc2b0..e897a975de28 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -39,6 +39,7 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } #include #include #include + #include "../../src/support/utils.h" #include "rpc_env.h" @@ -95,7 +96,16 @@ RPCEnv::RPCEnv(const std::string& wd) { auto cmdline = fopen("/proc/self/cmdline", "r"); fread(cwd, 1, sizeof(cwd), cmdline); fclose(cmdline); - base_ = "/data/data/" + std::string(cwd) + "/cache/rpc"; + std::string android_base_ = "/data/data/" + std::string(cwd) + "/cache"; + struct stat statbuf; + // Check if application data directory exist. If not exist, usually means we run tvm_rpc from + // adb shell terminal. + if (stat(android_base_.data(), &statbuf) == -1 || !S_ISDIR(statbuf.st_mode)) { + // Tmp directory is always writable for 'shell' user. + android_base_ = "/data/local/tmp"; + } + base_ = android_base_ + "/rpc"; + #elif !defined(_WIN32) char cwd[PATH_MAX]; if (getcwd(cwd, sizeof(cwd))) { diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 52b5da965b4c..5dc84105388b 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -22,7 +22,8 @@ * \brief RPC Server implementation. */ #include -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) +#include #include #include #endif @@ -52,7 +53,7 @@ namespace runtime { * \brief wait the child process end. * \param status status value */ -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) static pid_t waitPidEintr(int* status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { @@ -162,7 +163,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -203,7 +204,7 @@ class RPCServer { auto pid = fork(); if (pid == 0) { ServerLoopProc(conn, addr, work_dir_); - exit(0); + _exit(0); } // Wait for the result int status = 0; @@ -219,6 +220,10 @@ class RPCServer { auto dur = high_resolution_clock::now() - start_time; LOG(INFO) << "Serve Time " << duration_cast(dur).count() << "ms"; +#else + LOG(WARNING) << "Unknown platform. It is not known how to bring up the subprocess." + << " RPC will be launched in the main thread."; + ServerLoopProc(conn, addr, work_dir_); #endif // close from our side. LOG(INFO) << "Socket Connection Closed"; diff --git a/apps/microtvm/reference-vm/README.md b/apps/microtvm/reference-vm/README.md index 7ef7900c3e05..7ff75c75b4f9 100644 --- a/apps/microtvm/reference-vm/README.md +++ b/apps/microtvm/reference-vm/README.md @@ -49,19 +49,19 @@ Reference VMs are organized as follows: ## Creating Releases -1. 
Build the base box for the given platform: `$ ./base-box-tool.py build ` +1. Build the base box for the given platform: `$ ./base-box-tool.py [--provider=] build ` 2. Run release tests for each platform: 1. Connect any needed hardware to the VM host machine. - 2. Run tests: `$ ./base-box-tool.py test [--test-device-serial=]`. This + 2. Run tests: `$ ./base-box-tool.py [--provider=] test [--microtvm-platform=] [--test-device-serial=]`. This command does the following for each provider: 1. Copies all files inside `./` except `.vagrant` and `base-box` to `./release-test`. This is done to avoid reusing any VM the developer may have started. - 2. Executes `$ vagrant up --provider=`. + 2. Executes `$ vagrant up [--provider=]`. 3. Finds an attached USB device matching the VID and PID specified in `test-config.json`, and if `--test-device-serial` was given, that serial number (as reported to USB). Creates a rule to autoconnect this device to the VM, and also attaches it to the VM> 4. SSHs to the VM, `cd` to the TVM root directory, and runs `test_cmd` from `test-config.json`. Nonzero status means failure. 3. If release tests fail, fix them and restart from step 1. -4. If release tests pass: `$ ./base-box-tool.py release `. Be sure you've logged +4. If release tests pass: `$ ./base-box-tool.py [--provider=] release <--release-version=> <--platform-version=> `. Be sure you've logged in to Vagrant Cloud using the `vagrant` tool. diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index fb7a9c0b5ce6..c22eff4cdbad 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -34,7 +34,6 @@ THIS_DIR = os.path.realpath(os.path.dirname(__file__) or ".") - # List of vagrant providers supported by this tool ALL_PROVIDERS = ( "parallels", @@ -46,8 +45,11 @@ ALL_MICROTVM_PLATFORMS = ( "stm32f746xx", "nrf5340dk", + "mps2_an521", ) +PACKER_FILE_NAME = "packer.json" + def parse_virtualbox_devices(): output = subprocess.check_output(["VBoxManage", "list", "usbhost"], encoding="utf-8") @@ -173,12 +175,21 @@ def attach_vmware(uuid, vid_hex=None, pid_hex=None, serial=None): "vmware_desktop": attach_vmware, } +# Extra scripts required to execute on provisioning +# in zephyr/base-box/base_box_provision.sh +EXTRA_SCRIPTS = ( + "docker/install/ubuntu_init_zephyr_project.sh", + "docker/install/ubuntu_install_qemu.sh", +) + def generate_packer_config(file_path, providers): builders = [] + provisioners = [] for provider_name in providers: builders.append( { + "name": f"{provider_name}", "type": "vagrant", "box_name": f"microtvm-base-{provider_name}", "output_dir": f"output-packer-{provider_name}", @@ -189,10 +200,26 @@ def generate_packer_config(file_path, providers): } ) + repo_root = subprocess.check_output( + ["git", "rev-parse", "--show-toplevel"], cwd=os.path.dirname(__file__), encoding="utf-8" + ).strip() + for script in EXTRA_SCRIPTS: + script_path = os.path.join(repo_root, script) + filename = os.path.basename(script_path) + provisioners.append({"type": "file", "source": script_path, "destination": f"~/{filename}"}) + + provisioners.append( + { + "type": "shell", + "script": "base_box_provision.sh", + } + ) + with open(file_path, "w") as f: json.dump( { "builders": builders, + "provisioners": provisioners, }, f, sort_keys=True, @@ -202,7 +229,7 @@ def generate_packer_config(file_path, providers): def build_command(args): generate_packer_config( - os.path.join(THIS_DIR, args.platform, "base-box", "packer.json"), + 
os.path.join(THIS_DIR, args.platform, "base-box", PACKER_FILE_NAME), args.provider or ALL_PROVIDERS, ) env = copy.copy(os.environ) @@ -212,7 +239,7 @@ def build_command(args): if args.debug_packer: packer_args += ["-debug"] - packer_args += ["packer.json"] + packer_args += [PACKER_FILE_NAME] subprocess.check_call( packer_args, cwd=os.path.join(THIS_DIR, args.platform, "base-box"), env=env ) @@ -221,7 +248,6 @@ def build_command(args): REQUIRED_TEST_CONFIG_KEYS = { "vid_hex": str, "pid_hex": str, - "test_cmd": list, } @@ -284,7 +310,6 @@ def do_build_release_test_vm(release_test_dir, user_box_dir, base_box_dir, provi return_code = subprocess.call(remove_args, cwd=release_test_dir) assert return_code in (0, 1), f'{" ".join(remove_args)} returned exit code {return_code}' subprocess.check_call(["vagrant", "up", f"--provider={provider_name}"], cwd=release_test_dir) - return True @@ -293,18 +318,30 @@ def do_run_release_test(release_test_dir, provider_name, test_config, test_devic os.path.join(release_test_dir, ".vagrant", "machines", "default", provider_name, "id") ) as f: machine_uuid = f.read() - ATTACH_USB_DEVICE[provider_name]( - machine_uuid, - vid_hex=test_config["vid_hex"], - pid_hex=test_config["pid_hex"], - serial=test_device_serial, - ) + + # Check if target is not QEMU + if test_config["vid_hex"] and test_config["pid_hex"]: + ATTACH_USB_DEVICE[provider_name]( + machine_uuid, + vid_hex=test_config["vid_hex"], + pid_hex=test_config["pid_hex"], + serial=test_device_serial, + ) tvm_home = os.path.realpath(os.path.join(THIS_DIR, "..", "..", "..")) def _quote_cmd(cmd): return " ".join(shlex.quote(a) for a in cmd) - test_cmd = _quote_cmd(["cd", tvm_home]) + " && " + _quote_cmd(test_config["test_cmd"]) + test_cmd = ( + _quote_cmd(["cd", tvm_home]) + + " && " + + _quote_cmd( + [ + "apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh", + test_config["microtvm_platform"], + ] + ) + ) subprocess.check_call(["vagrant", "ssh", "-c", f"bash -ec '{test_cmd}'"], cwd=release_test_dir) @@ -325,6 +362,7 @@ def test_command(args): microtvm_test_platform["vid_hex"] = microtvm_test_platform["vid_hex"].lower() microtvm_test_platform["pid_hex"] = microtvm_test_platform["pid_hex"].lower() + microtvm_test_platform["microtvm_platform"] = args.microtvm_platform providers = args.provider provider_passed = {p: False for p in providers} @@ -399,18 +437,18 @@ def parse_args(): description="Automates building, testing, and releasing a base box" ) subparsers = parser.add_subparsers(help="Action to perform.") - parser.add_argument( - "platform", - help="Name of the platform VM to act on. Must be a sub-directory of this directory.", - ) parser.add_argument( "--provider", choices=ALL_PROVIDERS, action="append", - default=list(ALL_PROVIDERS), help="Name of the provider or providers to act on; if not specified, act on all.", ) + parser.add_argument( + "platform", + help="Name of the platform VM to act on. 
Must be a sub-directory of this directory.", + ) + parser_build = subparsers.add_parser("build", help="Build a base box.") parser_build.set_defaults(func=build_command) parser_test = subparsers.add_parser("test", help="Test a base box before release.") diff --git a/apps/microtvm/reference-vm/zephyr/Vagrantfile b/apps/microtvm/reference-vm/zephyr/Vagrantfile index 2778d7ca8a49..be41c0b733e5 100644 --- a/apps/microtvm/reference-vm/zephyr/Vagrantfile +++ b/apps/microtvm/reference-vm/zephyr/Vagrantfile @@ -46,7 +46,7 @@ Vagrant.configure("2") do |config| end end - config.vm.provision "shell", path: "setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false + config.vm.provision "shell", path: "provision_setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false # Enable USB Controller on VirtualBox vm_name = "microtvm-#{Time.now.tv_sec}" diff --git a/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template b/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template index 38f9a20b56cf..b43596bb83c1 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template +++ b/apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template @@ -41,7 +41,7 @@ Vagrant.configure("2") do |config| config.vm.provision "shell", inline: "touch ~/skip_zeroing_disk", privileged: false {{- end}} - # NOTE: setup.sh resides in the parent directory (../) because this template is expanded into a + # NOTE: base_box_setup.sh resides in the parent directory (../) because this template is expanded into a # sub-directory of base-box (output-packer-*). - config.vm.provision "shell", path: "../setup.sh", privileged: false + config.vm.provision "shell", path: "../base_box_setup.sh", privileged: false end diff --git a/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh new file mode 100644 index 000000000000..69e6171d06dd --- /dev/null +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh @@ -0,0 +1,37 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Using this script we can reuse docker/install scripts to configure the reference +# virtual machine similar to CI QEMU setup. +# + +set -e +set -x + +source ~/.profile + +# Init Zephyr +cd ~ +# Using most recent commit that passes all the tests. 
+~/ubuntu_init_zephyr_project.sh ~/zephyr v2.5-branch --commit dabf23758417fd041fec2a2a821d8f526afac29d + +# Build QEMU +sudo ~/ubuntu_install_qemu.sh --target-list arm-softmmu + +# Cleanup +rm -f *.sh diff --git a/apps/microtvm/reference-vm/zephyr/base-box/setup.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh similarity index 97% rename from apps/microtvm/reference-vm/zephyr/base-box/setup.sh rename to apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh index 8f7ed41af337..e8385af9f663 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/setup.sh +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_setup.sh @@ -17,6 +17,7 @@ # under the License. set -e +set -x skip_zeroing_disk=0 if [ -e "$HOME/skip_zeroing_disk" ]; then @@ -81,8 +82,6 @@ pip3 install --user -U west echo 'export PATH=$HOME/.local/bin:"$PATH"' >> ~/.profile source ~/.profile echo PATH=$PATH -REPO_ROOT=$(git rev-parse --show-toplevel) -${REPO_ROOT}/docker/install/ubuntu_init_zephyr_project.sh ~/zephyr v2.5.0 cd ~ echo "Downloading zephyr SDK..." diff --git a/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh new file mode 100755 index 000000000000..8eba63e9e331 --- /dev/null +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_test.sh @@ -0,0 +1,39 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Usage: base_box_test.sh +# Execute microTVM Zephyr tests. 
+# + +set -e +set -x + +if [ "$#" -lt 1 ]; then + echo "Usage: base_box_test.sh " + exit -1 +fi + +microtvm_platform=$1 + +pytest tests/micro/zephyr/test_zephyr.py --microtvm-platforms=${microtvm_platform} + +if [ $microtvm_platform == "stm32f746xx" ]; then + echo "NOTE: skipped test_zephyr_aot.py on $microtvm_platform -- known failure" +else + pytest tests/micro/zephyr/test_zephyr_aot.py --microtvm-platforms=${microtvm_platform} +fi diff --git a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json index 1a39d34c7e64..48b6915a10f4 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json +++ b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json @@ -1,12 +1,14 @@ { "stm32f746xx": { "vid_hex": "0483", - "pid_hex": "374b", - "test_cmd": ["pytest", "tests/micro/zephyr/test_zephyr.py", "--microtvm-platforms=stm32f746xx"] + "pid_hex": "374b" }, "nrf5340dk": { "vid_hex": "1366", - "pid_hex": "1055", - "test_cmd": ["pytest", "tests/micro/zephyr/test_zephyr.py", "--microtvm-platforms=nrf5340dk"] + "pid_hex": "1055" + }, + "mps2_an521": { + "vid_hex": "", + "pid_hex": "" } } diff --git a/apps/microtvm/reference-vm/zephyr/setup.sh b/apps/microtvm/reference-vm/zephyr/provision_setup.sh similarity index 95% rename from apps/microtvm/reference-vm/zephyr/setup.sh rename to apps/microtvm/reference-vm/zephyr/provision_setup.sh index e0f382cfc23e..f95c7e24f5aa 100644 --- a/apps/microtvm/reference-vm/zephyr/setup.sh +++ b/apps/microtvm/reference-vm/zephyr/provision_setup.sh @@ -24,6 +24,7 @@ cd "${TVM_HOME}" apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +# Build poetry cd apps/microtvm/reference-vm/zephyr poetry env use 3.6 @@ -41,7 +42,7 @@ echo "downloaded and cached for future use." echo "------------------------------[ TVM Message ]------------------------------" poetry lock -vvv poetry install -poetry run pip3 install -r ~/zephyr/zephyr/scripts/requirements.txt +poetry run pip3 install -r ${ZEPHYR_BASE}/scripts/requirements.txt echo "export TVM_LIBRARY_PATH=\"$TVM_HOME\"/build-microtvm" >>~/.profile echo "VENV_PATH=\$((cd \"$TVM_HOME\"/apps/microtvm/reference-vm/zephyr && poetry env list --full-path) | sed -E 's/^(.*)[[:space:]]\(Activated\)\$/\1/g')" >>~/.profile diff --git a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh index 2eb55e385520..1cebcf7166af 100755 --- a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +++ b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh @@ -18,6 +18,14 @@ set -e +# Get number of cores for build +if [ -n "${TVM_CI_NUM_CORES}" ]; then + num_cores=${TVM_CI_NUM_CORES} +else + # default setup for Vagrantfile + num_cores=2 +fi + cd "$(dirname $0)" cd "$(git rev-parse --show-toplevel)" BUILD_DIR=build-microtvm @@ -32,4 +40,4 @@ sed -i 's/USE_GRAPH_EXECUTOR_DEBUG OFF/USE_GRAPH_EXECUTOR_DEBUG ON/' config.cmak sed -i 's/USE_LLVM OFF/USE_LLVM ON/' config.cmake cmake .. 
rm -rf standalone_crt host_standalone_crt # remove stale generated files -make -j4 +make -j${num_cores} diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/aot_demo/src/main.c index b92366a7098b..7ee812ffc33e 100644 --- a/apps/microtvm/zephyr/aot_demo/src/main.c +++ b/apps/microtvm/zephyr/aot_demo/src/main.c @@ -83,26 +83,26 @@ void timer_expiry_function(struct k_timer* timer_id) { return; } #define MILLIS_TIL_EXPIRY 200 #define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) -struct k_timer g_utvm_timer; -uint32_t g_utvm_start_time; -int g_utvm_timer_running = 0; +struct k_timer g_microtvm_timer; +uint32_t g_microtvm_start_time; +int g_microtvm_timer_running = 0; // Called to start system timer. tvm_crt_error_t TVMPlatformTimerStart() { - if (g_utvm_timer_running) { + if (g_microtvm_timer_running) { TVMLogf("timer already running"); return kTvmErrorPlatformTimerBadState; } - k_timer_start(&g_utvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); - g_utvm_start_time = k_cycle_get_32(); - g_utvm_timer_running = 1; + k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); + g_microtvm_start_time = k_cycle_get_32(); + g_microtvm_timer_running = 1; return kTvmErrorNoError; } // Called to stop system timer. tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_utvm_timer_running) { + if (!g_microtvm_timer_running) { TVMLogf("timer not running"); return kTvmErrorSystemErrorMask | 2; } @@ -110,11 +110,11 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { uint32_t stop_time = k_cycle_get_32(); // compute how long the work took - uint32_t cycles_spent = stop_time - g_utvm_start_time; - if (stop_time < g_utvm_start_time) { + uint32_t cycles_spent = stop_time - g_microtvm_start_time; + if (stop_time < g_microtvm_start_time) { // we rolled over *at least* once, so correct the rollover it was *only* // once, because we might still use this result - cycles_spent = ~((uint32_t)0) - (g_utvm_start_time - stop_time); + cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time); } uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent); @@ -122,13 +122,13 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { // need to grab time remaining *before* stopping. when stopped, this function // always returns 0. - int32_t time_remaining_ms = k_timer_remaining_get(&g_utvm_timer); - k_timer_stop(&g_utvm_timer); + int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer); + k_timer_stop(&g_microtvm_timer); // check *after* stopping to prevent extra expiries on the happy path if (time_remaining_ms < 0) { return kTvmErrorSystemErrorMask | 3; } - uint32_t num_expiries = k_timer_status_get(&g_utvm_timer); + uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer); uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); double approx_num_cycles = (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); @@ -140,7 +140,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { *elapsed_time_seconds = hw_clock_res_us / 1e6; } - g_utvm_timer_running = 0; + g_microtvm_timer_running = 0; return kTvmErrorNoError; } @@ -172,7 +172,7 @@ void main(void) { g_cmd_buf_ind = 0; memset((char*)cmd_buf, 0, sizeof(cmd_buf)); TVMPlatformUARTInit(); - k_timer_init(&g_utvm_timer, NULL, NULL); + k_timer_init(&g_microtvm_timer, NULL, NULL); // Wake up host side. 
TVMPlatformWriteSerial(g_wakeup_sequence, sizeof(g_wakeup_sequence)); diff --git a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c index 1f4dde1de4b9..c9eec8751100 100644 --- a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c +++ b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c @@ -23,7 +23,7 @@ #include "crt_config.h" -static const struct device* g_utvm_uart; +static const struct device* g_microtvm_uart; #define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) // Ring buffer used to store data read from the UART on rx interrupt. @@ -68,7 +68,7 @@ uint32_t TVMPlatformUartRxRead(uint8_t* data, uint32_t data_size_bytes) { uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { for (uint32_t i = 0; i < size; i++) { - uart_poll_out(g_utvm_uart, data[i]); + uart_poll_out(g_microtvm_uart, data[i]); } return size; } @@ -76,6 +76,6 @@ uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { // Initialize UART void TVMPlatformUARTInit() { // Claim console device. - g_utvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console))); - uart_rx_init(&uart_rx_rbuf, g_utvm_uart); + g_microtvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console))); + uart_rx_init(&uart_rx_rbuf, g_microtvm_uart); } diff --git a/apps/microtvm/zephyr/host_driven/src/main.c b/apps/microtvm/zephyr/host_driven/src/main.c index 637a58ae92fd..5b93d647eb00 100644 --- a/apps/microtvm/zephyr/host_driven/src/main.c +++ b/apps/microtvm/zephyr/host_driven/src/main.c @@ -39,7 +39,7 @@ #include #include #include -#include +#include #include #include @@ -146,14 +146,14 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { #define MILLIS_TIL_EXPIRY 200 #define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) -K_TIMER_DEFINE(g_utvm_timer, /* expiry func */ NULL, /* stop func */ NULL); +K_TIMER_DEFINE(g_microtvm_timer, /* expiry func */ NULL, /* stop func */ NULL); -uint32_t g_utvm_start_time; -int g_utvm_timer_running = 0; +uint32_t g_microtvm_start_time; +int g_microtvm_timer_running = 0; // Called to start system timer. tvm_crt_error_t TVMPlatformTimerStart() { - if (g_utvm_timer_running) { + if (g_microtvm_timer_running) { TVMLogf("timer already running"); return kTvmErrorPlatformTimerBadState; } @@ -161,15 +161,15 @@ tvm_crt_error_t TVMPlatformTimerStart() { #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 1); #endif - k_timer_start(&g_utvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); - g_utvm_start_time = k_cycle_get_32(); - g_utvm_timer_running = 1; + k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); + g_microtvm_start_time = k_cycle_get_32(); + g_microtvm_timer_running = 1; return kTvmErrorNoError; } // Called to stop system timer. 
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_utvm_timer_running) { + if (!g_microtvm_timer_running) { TVMLogf("timer not running"); return kTvmErrorSystemErrorMask | 2; } @@ -180,11 +180,11 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { #endif // compute how long the work took - uint32_t cycles_spent = stop_time - g_utvm_start_time; - if (stop_time < g_utvm_start_time) { + uint32_t cycles_spent = stop_time - g_microtvm_start_time; + if (stop_time < g_microtvm_start_time) { // we rolled over *at least* once, so correct the rollover it was *only* // once, because we might still use this result - cycles_spent = ~((uint32_t)0) - (g_utvm_start_time - stop_time); + cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time); } uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent); @@ -192,14 +192,14 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { // need to grab time remaining *before* stopping. when stopped, this function // always returns 0. - int32_t time_remaining_ms = k_timer_remaining_get(&g_utvm_timer); - k_timer_stop(&g_utvm_timer); + int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer); + k_timer_stop(&g_microtvm_timer); // check *after* stopping to prevent extra expiries on the happy path if (time_remaining_ms < 0) { TVMLogf("negative time remaining"); return kTvmErrorSystemErrorMask | 3; } - uint32_t num_expiries = k_timer_status_get(&g_utvm_timer); + uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer); uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); double approx_num_cycles = (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); @@ -211,7 +211,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { *elapsed_time_seconds = hw_clock_res_us / 1e6; } - g_utvm_timer_running = 0; + g_microtvm_timer_running = 0; return kTvmErrorNoError; } @@ -285,14 +285,14 @@ void main(void) { uart_rx_init(&uart_rx_rbuf, tvm_uart); // Initialize microTVM RPC server, which will receive commands from the UART and execute them. - utvm_rpc_server_t server = UTvmRpcServerInit(write_serial, NULL); + microtvm_rpc_server_t server = MicroTVMRpcServerInit(write_serial, NULL); TVMLogf("microTVM Zephyr runtime - running"); #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 0); #endif // The main application loop. We continuously read commands from the UART - // and dispatch them to UTvmRpcServerLoop(). + // and dispatch them to MicroTVMRpcServerLoop(). while (true) { uint8_t* data; unsigned int key = irq_lock(); @@ -302,7 +302,7 @@ void main(void) { size_t bytes_remaining = bytes_read; while (bytes_remaining > 0) { // Pass the received bytes to the RPC server. - tvm_crt_error_t err = UTvmRpcServerLoop(server, &data, &bytes_remaining); + tvm_crt_error_t err = MicroTVMRpcServerLoop(server, &data, &bytes_remaining); if (err != kTvmErrorNoError && err != kTvmErrorFramingShortPacket) { TVMPlatformAbort(err); } diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index 620f7552cef6..09f2ccc95d85 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -34,7 +34,7 @@ if(USE_MICRO) # Build an isolated build directory, separate from the TVM tree. 
list(APPEND CRT_FILE_COPY_JOBS "3rdparty/libcrc/include *.h -> include" - "3rdparty/libcrc/src crcccitt.c -> src/runtime/crt/utvm_rpc_common" + "3rdparty/libcrc/src crcccitt.c -> src/runtime/crt/microtvm_rpc_common" "3rdparty/libcrc/tab gentab_ccitt.inc -> src/runtime/crt/tab" "3rdparty/dlpack/include *.h -> include" "3rdparty/dmlc-core/include *.h -> include" @@ -49,8 +49,8 @@ if(USE_MICRO) "src/runtime/crt/host crt_config.h -> template/host" "src/runtime/crt/host *.cc -> template/host" "src/runtime/crt/memory *.c -> src/runtime/crt/memory" - "src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common" - "src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server" + "src/runtime/crt/microtvm_rpc_common *.cc -> src/runtime/crt/microtvm_rpc_common" + "src/runtime/crt/microtvm_rpc_server *.cc -> src/runtime/crt/microtvm_rpc_server" "src/runtime/minrpc *.h -> src/runtime/minrpc" "src/support generic_arena.h -> src/support" "src/runtime/crt crt_config-template.h -> template" @@ -98,7 +98,7 @@ if(USE_MICRO) set(make_quiet ) endif(${VERBOSE}) - list(APPEND crt_libraries memory graph_executor aot_executor utvm_rpc_server utvm_rpc_common common) # NOTE: listed in link order. + list(APPEND crt_libraries memory graph_executor aot_executor microtvm_rpc_server microtvm_rpc_common common) # NOTE: listed in link order. foreach(crt_lib_name IN LISTS crt_libraries) list(APPEND crt_library_paths "host_standalone_crt/lib${crt_lib_name}.a") endforeach() @@ -166,7 +166,7 @@ if(USE_MICRO) tvm_crt_define_targets() - set(TVM_CRT_LINKER_LIB host_standalone_crt_utvm_rpc_common) + set(TVM_CRT_LINKER_LIB host_standalone_crt_microtvm_rpc_common) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${TVM_CRT_LINKER_LIB} -Wl,--no-whole-archive) elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES ".*Clang") diff --git a/cmake/modules/contrib/ArmComputeLib.cmake b/cmake/modules/contrib/ArmComputeLib.cmake index 54ce917dfb50..9f6b9c1a058e 100644 --- a/cmake/modules/contrib/ArmComputeLib.cmake +++ b/cmake/modules/contrib/ArmComputeLib.cmake @@ -27,6 +27,10 @@ if(USE_ARM_COMPUTE_LIB) if(NOT USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR) list(APPEND COMPILER_SRCS ${ACL_RUNTIME_MODULE}) endif() + if(NOT DEFINED TVM_LLVM_VERSION) + message(FATAL_ERROR "Support for offloading to Compute library for the Arm Architecture requires LLVM Support") + endif() + message(STATUS "Build with Arm Compute Library support...") endif() diff --git a/cmake/modules/contrib/EthosN.cmake b/cmake/modules/contrib/EthosN.cmake index 44d2a2a17ace..6eb5271f91b9 100644 --- a/cmake/modules/contrib/EthosN.cmake +++ b/cmake/modules/contrib/EthosN.cmake @@ -20,6 +20,10 @@ if(NOT USE_ETHOSN STREQUAL "OFF") find_ethosn(${USE_ETHOSN}) + if(NOT DEFINED TVM_LLVM_VERSION) + message(FATAL_ERROR "Support for offloading to Ethos-N requires LLVM Support") + endif() + if(NOT ETHOSN_FOUND) message(FATAL_ERROR "Cannot find Ethos-N, USE_ETHOSN=" ${USE_ETHOSN}) diff --git a/cmake/utils/FindOpenCL.cmake b/cmake/utils/FindOpenCL.cmake index 9b9f8ec94593..c65d46ecab28 100644 --- a/cmake/utils/FindOpenCL.cmake +++ b/cmake/utils/FindOpenCL.cmake @@ -50,7 +50,7 @@ macro(find_opencl use_opencl) if (CMAKE_FIND_ROOT_PATH_MODE_LIBRARY STREQUAL "ONLY") set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) endif() - find_library(OpenCL_LIBRARIES NAMES OpenCL PATHS ${__opencl_sdk}/lib ${__opencl_sdk}/lib64) + find_library(OpenCL_LIBRARIES NAMES OpenCL PATHS ${__opencl_sdk}/lib ${__opencl_sdk}/lib64 ${__opencl_sdk}/lib/x64/) 
if(OpenCL_LIBRARIES) set(OpenCL_FOUND TRUE) endif() diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 671ce04e8c1d..9479d7194d3b 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -43,5 +43,5 @@ COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh # Arm(R) Compute Library -COPY install/ubuntu_install_arm_compute_lib.sh /install/ubuntu_install_arm_compute_lib.sh -RUN bash /install/ubuntu_install_arm_compute_lib.sh +COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh +RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 1ca592f34ab2..7b511deca343 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -87,9 +87,9 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh RUN bash /install/ubuntu_install_tensorflow.sh -# Arm(R) Compute Library -COPY install/ubuntu_install_arm_compute_lib.sh /install/ubuntu_install_arm_compute_lib.sh -RUN bash /install/ubuntu_install_arm_compute_lib.sh +# Compute Library +COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh +RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh # Caffe deps COPY install/ubuntu_install_caffe.sh /install/ubuntu_install_caffe.sh @@ -102,3 +102,10 @@ RUN bash /install/ubuntu_install_ethosn_driver_stack.sh # Vitis-AI PyXIR CI deps COPY install/ubuntu_install_vitis_ai_packages_ci.sh /install/ubuntu_install_vitis_ai_packages_ci.sh RUN bash /install/ubuntu_install_vitis_ai_packages_ci.sh + +# Android SDK +COPY install/ubuntu_install_androidsdk.sh /install/ubuntu_install_androidsdk.sh +RUN bash /install/ubuntu_install_androidsdk.sh +ENV ANDROID_HOME=/opt/android-sdk-linux/ +ENV ANDROID_NDK_HOME=/opt/android-sdk-linux/ndk/21.3.6528147/ + diff --git a/docker/install/ubuntu_install_arm_compute_lib.sh b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh similarity index 53% rename from docker/install/ubuntu_install_arm_compute_lib.sh rename to docker/install/ubuntu_download_arm_compute_lib_binaries.sh index c09bb1290a63..e71cff0f7ba6 100755 --- a/docker/install/ubuntu_install_arm_compute_lib.sh +++ b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh @@ -17,62 +17,47 @@ # under the License. set -e -set -u -set -o pipefail - -repo_url="https://github.com/ARM-software/ComputeLibrary.git" -repo_dir="acl" -install_path="/opt/$repo_dir" -architecture_type=$(uname -i) -target_arch="arm64-v8a" # arm64-v8a / arm64-v8.2-a / armv7a -build_type="native" - -tmpdir=$(mktemp -d) - -cleanup() -{ - rm -rf "$tmpdir" -} - -trap cleanup 0 - -apt-get update && \ -apt-get install -y --no-install-recommends \ - git \ - scons \ - bsdmainutils \ - build-essential # Install cross-compiler when not building natively. # Depending on the architecture selected to compile for, # you may need to install an alternative cross-compiler. 
if [ "$architecture_type" != "aarch64" ]; then - apt-get install -y --no-install-recommends \ + apt-get update && apt-get install -y --no-install-recommends \ g++-aarch64-linux-gnu \ gcc-aarch64-linux-gnu fi -cd "$tmpdir" +compute_lib_version="v21.05" +compute_lib_base_url="https://github.com/ARM-software/ComputeLibrary/releases/download/${compute_lib_version}" +compute_lib_file_name="arm_compute-${compute_lib_version}-bin-linux.tar.gz" +compute_lib_download_url="${compute_lib_base_url}/${compute_lib_file_name}" -git clone "$repo_url" "$repo_dir" +target_lib="linux-arm64-v8a-neon" -cd "$repo_dir" +# uncomment line below if you need asserts/debug version of the library +# target_lib="${target_lib}-asserts" -# pin version to v21.02 -git checkout "v21.02" +extract_dir="arm_compute-${compute_lib_version}-bin-linux" +install_path="/opt/acl" -if [ "$architecture_type" != "aarch64" ]; then - build_type="cross_compile" -fi +tmpdir=$(mktemp -d) + +cleanup() +{ + rm -rf "$tmpdir" +} + +trap cleanup 0 + +cd "$tmpdir" + +curl -sL "${compute_lib_download_url}" -o "${compute_lib_file_name}" +tar xzf "${compute_lib_file_name}" -scons \ - install_dir="$install_path" \ - Werror=1 \ - -j8 \ - debug=0 \ - asserts=0 \ - neon=1 \ - opencl=0 \ - os=linux \ - arch="$target_arch" \ - build="$build_type" +rm -rf "${install_path}" +mkdir -p "${install_path}" +cp -r "${extract_dir}/include" "${install_path}/" +cp -r "${extract_dir}/arm_compute" "${install_path}/include/" +cp -r "${extract_dir}/support" "${install_path}/include/" +cp -r "${extract_dir}/utils" "${install_path}/include/" +cp -r "${extract_dir}/lib/${target_lib}" "${install_path}/lib" diff --git a/docker/install/ubuntu_init_zephyr_project.sh b/docker/install/ubuntu_init_zephyr_project.sh index 2116a4d981f5..573ff30c38a8 100755 --- a/docker/install/ubuntu_init_zephyr_project.sh +++ b/docker/install/ubuntu_init_zephyr_project.sh @@ -16,10 +16,35 @@ # specific language governing permissions and limitations # under the License. +# +# Initialize Zephyr Project. +# +# Usage: ubuntu_init_zephyr_project.sh path branch [--commit hash] +# path is the installation path for the repository. +# branch is the zephyr branch. +# --commit is the commit hash number of zephyrproject repository. If not specified, it uses the latest commit. 
+# + +set -x + DOWNLOAD_DIR=$1 -ZEPHYR_BRANCH=$2 +shift +ZEPHYR_BRANCH=$1 +shift + +commit_hash= +if [ "$1" == "--commit" ]; then + shift + commit_hash=$1 +fi west init --mr ${ZEPHYR_BRANCH} ${DOWNLOAD_DIR} + +if [ -n "$commit_hash" ]; then + cd ${DOWNLOAD_DIR}/zephyr + git checkout ${commit_hash} +fi + cd ${DOWNLOAD_DIR} west update west zephyr-export diff --git a/docker/install/ubuntu_install_core.sh b/docker/install/ubuntu_install_core.sh index 6f9d791a650d..2a50afcf5985 100755 --- a/docker/install/ubuntu_install_core.sh +++ b/docker/install/ubuntu_install_core.sh @@ -24,7 +24,7 @@ set -o pipefail apt-get update && apt-get install -y --no-install-recommends \ git make libgtest-dev cmake wget unzip libtinfo-dev libz-dev\ libcurl4-openssl-dev libssl-dev libopenblas-dev g++ sudo \ - apt-transport-https graphviz pkg-config + apt-transport-https graphviz pkg-config curl cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib diff --git a/docker/install/ubuntu_install_nodejs.sh b/docker/install/ubuntu_install_nodejs.sh index b36da6295ec0..79c2c3e4b19c 100755 --- a/docker/install/ubuntu_install_nodejs.sh +++ b/docker/install/ubuntu_install_nodejs.sh @@ -21,7 +21,6 @@ set -u set -o pipefail apt-get update -apt-get install -y curl # The node install script fetched and executed here will update the # apt source list, hence the second apt-get update is necessary. diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index a92a0244d707..8f462284c2ba 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -20,9 +20,10 @@ set -e set -u set -o pipefail -# fix to certain version for now -pip3 install onnx==1.6.0 -pip3 install onnxruntime==1.0.0 +# We need to fix the onnx version because changing versions tends to break tests +# TODO(mbrookhart): periodically update +pip3 install onnx==1.8.1 +pip3 install onnxruntime==1.7.0 # torch depends on a number of other packages, but unhelpfully, does # not expose that in the wheel!!! 
diff --git a/docker/install/ubuntu_install_qemu.sh b/docker/install/ubuntu_install_qemu.sh index ad4037a34187..b1d375253e05 100755 --- a/docker/install/ubuntu_install_qemu.sh +++ b/docker/install/ubuntu_install_qemu.sh @@ -26,6 +26,13 @@ set -e set -o pipefail +QEMU_NAME=qemu-5.1.0 +QEMU_SIG_FILE=${QEMU_NAME}.tar.xz.sig +QEMU_TAR_FILE=${QEMU_NAME}.tar.xz + +# Clean previous build +rm -rf ${QEMU_NAME} ${QEMU_SIG_FILE} ${QEMU_TAR_FILE} + # Get number of cores for build if [ -n "${TVM_CI_NUM_CORES}" ]; then num_cores=${TVM_CI_NUM_CORES} @@ -62,10 +69,11 @@ p5ez/+2k4VAIwIQoP5DoO06waLBffvLIAdPPKYsx71K67OoGG2svc7duC/+5qf1x -----END PGP ARMORED FILE----- EOF curl -OLs https://download.qemu.org/qemu-5.1.0.tar.xz -gpg --verify qemu-5.1.0.tar.xz.sig +gpg --verify ${QEMU_SIG_FILE} + +tar -xf ${QEMU_TAR_FILE} -tar -xf qemu-5.1.0.tar.xz -cd qemu-5.1.0 +cd ${QEMU_NAME} ./configure --target-list=${target_list} make -j${num_cores} sudo make install diff --git a/docker/install/ubuntu_install_redis.sh b/docker/install/ubuntu_install_redis.sh index 21dab1f3a57b..0eb46eb8edec 100755 --- a/docker/install/ubuntu_install_redis.sh +++ b/docker/install/ubuntu_install_redis.sh @@ -21,4 +21,4 @@ set -u set -o pipefail apt-get update && apt-get install -y redis-server -pip3 install xgboost>=1.1.0 psutil +pip3 install "xgboost>=1.1.0" psutil diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh index 5716b11db6c4..c9f06e8e982e 100755 --- a/docker/install/ubuntu_install_rust.sh +++ b/docker/install/ubuntu_install_rust.sh @@ -20,7 +20,6 @@ set -e set -u set -o pipefail -apt-get update && apt-get install -y --no-install-recommends curl export RUSTUP_HOME=/opt/rust export CARGO_HOME=/opt/rust diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 33757a0d4d57..80b3323d5956 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -20,4 +20,5 @@ set -e set -u set -o pipefail -pip3 install sphinx sphinx-gallery==0.4.0 autodocsumm sphinx_rtd_theme sphinx_autodoc_annotation matplotlib Image "commonmark>=0.7.3" "docutils>=0.11" +# NOTE: install docutils < 0.17 to work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 +pip3 install sphinx sphinx-gallery==0.4.0 autodocsumm sphinx_rtd_theme sphinx_autodoc_annotation matplotlib Image "commonmark>=0.7.3" "docutils>=0.11" "docutils<0.17" diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 286a086abd82..81802964ba0e 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -20,7 +20,4 @@ set -e set -u set -o pipefail -# h5py is pinned to minor than 3 due to issues with -# tensorflow: -# https://github.com/tensorflow/tensorflow/issues/44467 -pip3 install tensorflow==2.3.1 keras==2.4.3 "h5py<3.0" +pip3 install tensorflow==2.4.2 diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 2dfbb0681a80..cfaa643d50f3 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -20,6 +20,10 @@ set -e set -u set -o pipefail +# The tflite version should have matched versions to the tensorflow +# version installed from pip in ubuntu_install_tensorflow.sh +TENSORFLOW_VERSION=$(python3 -c "import tensorflow; print(tensorflow.__version__)" 2> /dev/null) + # Download, build and install flatbuffers git clone --branch=v1.12.0 --depth=1 --recursive 
https://github.com/google/flatbuffers.git cd flatbuffers @@ -33,14 +37,14 @@ pip3 install flatbuffers # Build the TFLite static library, necessary for building with TFLite ON. # The library is built at: # tensorflow/tensorflow/lite/tools/make/gen/*/lib/libtensorflow-lite.a. -git clone https://github.com/tensorflow/tensorflow --branch=r2.3 +git clone https://github.com/tensorflow/tensorflow --branch=v${TENSORFLOW_VERSION} ./tensorflow/tensorflow/lite/tools/make/download_dependencies.sh ./tensorflow/tensorflow/lite/tools/make/build_lib.sh # Setup tflite from schema mkdir tflite +cp tensorflow/tensorflow/lite/schema/schema.fbs tflite cd tflite -wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r2.3/tensorflow/lite/schema/schema.fbs flatc --python schema.fbs cat <setup.py @@ -48,7 +52,7 @@ import setuptools setuptools.setup( name="tflite", - version="2.3.1", + version="${TENSORFLOW_VERSION}", author="google", author_email="google@google.com", description="TFLite", diff --git a/docs/README.txt b/docs/README.txt index 1da0c833a256..06d7a0f6e444 100644 --- a/docs/README.txt +++ b/docs/README.txt @@ -3,7 +3,7 @@ TVM Documentations This folder contains the source of TVM documents - A hosted version of doc is at https://tvm.apache.org/docs -- pip install "sphinx>=1.5.5" sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark "Pillow<7" "autodocsumm<0.2.0" tlcpack-sphinx-addon +- pip install "sphinx>=1.5.5" sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark "Pillow<7" "autodocsumm<0.2.0" tlcpack-sphinx-addon "docutils<0.17" - (Versions 0.2.0 to 0.2.2 of autodocsumm are incompatible with sphinx>=3.4, https://github.com/Chilipp/autodocsumm/pull/42 ) - Build tvm first in the root folder. - Run the following command @@ -56,4 +56,4 @@ You will need a gpu CI environment. Define the Order of Tutorials ----------------------------- You can define the order of tutorials with `conf.py::subsection_order` and `conf.py::within_subsection_order`. -By default, the tutorials within one subsection is sorted by filename. \ No newline at end of file +By default, the tutorials within one subsection is sorted by filename. diff --git a/docs/deploy/arm_compute_lib.rst b/docs/deploy/arm_compute_lib.rst index 1abc31bbb422..6fb531a0a8f6 100644 --- a/docs/deploy/arm_compute_lib.rst +++ b/docs/deploy/arm_compute_lib.rst @@ -15,8 +15,8 @@ specific language governing permissions and limitations under the License. -Relay Arm:sup:`®` Compute Library Integration -============================================== +Relay Arm\ :sup:`®` Compute Library Integration +=============================================== **Author**: `Luke Hutton `_ Introduction diff --git a/docs/index.rst b/docs/index.rst index e3cf466d3cf1..a7ae68c87b01 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,6 +44,7 @@ For Developers contribute/index deploy/index dev/how_to + microtvm/index errors faq @@ -76,7 +77,6 @@ For Developers :hidden: :caption: MISC - microtvm/index vta/index diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 257fe085bfe5..49d3a42d3e98 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -80,6 +80,17 @@ Here is another example to match an op with a specific attribute: y = relay.var('y') assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) +Or a convolution with a specific kernel size: + +.. 
code-block:: python + + def test_match_kernel_size(): + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + x = relay.var('x') + y = relay.var('y') + assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3])) + + Matching an Optional Op *********************** diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 71a69a000944..418d532fdd5f 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -42,17 +43,68 @@ #include namespace tvm { + +/*! + * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) + * \param mod The IRmodule to lower + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); + +/*! + * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list + * defined in CreatePassList) + * \param func The PrimFunc to lower + * \param name The name of the lowered function. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, + bool simple_mode = false); + /*! - * \brief Build an IRModule given a schedule, args and binds - * \param sch The schedule to lower. + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. * \return The result module. */ -TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds); +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Build an IRModule given a TE schedule, args and binds. This function also applies + * the lowering passes defined in CreatePassList. + * \param sch The TE schedule to lower. + * \param args The arguments to the function (Array of Tensor, Buffer and Vars) + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param simple_mode Disables the loop partition pass. Defaults to false. + * \return The result module. + */ +TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, + const std::string& name, + const std::unordered_map& binds, + bool simple_mode = false); + +/*! + * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want + * to apply lowering passes as well, use LowerSchedule. + * \param sch The schedule + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds); /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. 
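As context for the include/tvm/driver/driver_api.h hunk above, the following is a minimal, illustrative sketch (not part of this patch) of how the renamed lowering entry points might be called from C++. The wrapper function name is hypothetical, the container template parameters (te::Tensor / tir::Buffer) are assumed for the Array and unordered_map arguments, and the TE construction helpers used (te::placeholder, te::compute, te::create_schedule) are the standard TVM C++ API.

```cpp
// Illustrative sketch only (not part of this patch): build a trivial TE
// schedule and lower it with the entry points declared in driver_api.h.
#include <tvm/driver/driver_api.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>

#include <string>
#include <unordered_map>

// Hypothetical helper: lowers a simple vector-add schedule to an IRModule.
tvm::IRModule LowerVectorAdd() {
  using namespace tvm;
  tir::Var n = te::var("n");
  te::Tensor A = te::placeholder({n}, DataType::Float(32), "A");
  te::Tensor B = te::placeholder({n}, DataType::Float(32), "B");
  te::Tensor C = te::compute(
      {n}, [&](tir::Var i) { return A(i) + B(i); }, "C");
  te::Schedule sch = te::create_schedule({C->op});

  std::unordered_map<te::Tensor, tir::Buffer> binds;
  // ScheduleToModule(sch, ...) would only convert the schedule to an IRModule;
  // LowerSchedule additionally applies the lowering pass list.
  return LowerSchedule(sch, Array<te::Tensor>{A, B, C}, "vector_add", binds);
}
```

The sketch mirrors the documented split in the new API: ScheduleToModule for conversion only, LowerSchedule (or LowerModule / LowerPrimFunc) when the lowering passes should run as well.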
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index a18d42902503..683170026451 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -146,7 +146,7 @@ class OpNode : public RelayExprNode { // Internal function to compute if it is primitive op bool IsPrimitiveOp_() const { const auto& fn_ty = this->op_type; - ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << "is not registered"; + ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered"; if (fn_ty->type_constraints.size() != 1) return false; const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); if (rel == nullptr) return false; diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 6a7f3f3190a0..b53ca2921fe7 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -45,6 +45,8 @@ class SourceNameNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + static constexpr bool _type_has_method_sequal_reduce = true; + bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const { return equal(name, other->name); } @@ -98,6 +100,7 @@ class SpanNode : public Object { v->Visit("end_line", &end_line); v->Visit("end_column", &end_column); } + static constexpr bool _type_has_method_sequal_reduce = true; bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { return equal(source_name, other->source_name) && equal(line, other->line) && diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 69a9c64a4588..a8317e1e51ad 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -148,7 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; Optional index_rank; - TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { + TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); TVM_ATTR_FIELD(index_rank) .set_default(NullValue()) diff --git a/include/tvm/runtime/crt/utvm_rpc_server.h b/include/tvm/runtime/crt/microtvm_rpc_server.h similarity index 76% rename from include/tvm/runtime/crt/utvm_rpc_server.h rename to include/tvm/runtime/crt/microtvm_rpc_server.h index b4fb2b8ad03b..9a7ed54ffe95 100644 --- a/include/tvm/runtime/crt/utvm_rpc_server.h +++ b/include/tvm/runtime/crt/microtvm_rpc_server.h @@ -18,12 +18,12 @@ */ /*! - * \file utvm_rpc_server.h + * \file microtvm_rpc_server.h * \brief MicroTVM RPC Server */ -#ifndef TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ -#define TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ +#ifndef TVM_RUNTIME_CRT_MICROTVM_RPC_SERVER_H_ +#define TVM_RUNTIME_CRT_MICROTVM_RPC_SERVER_H_ #include #include @@ -40,14 +40,15 @@ extern "C" { * \param num_bytes Number of bytes avaiable in data. * \return The number of bytes written. */ -typedef ssize_t (*utvm_rpc_channel_write_t)(void* context, const uint8_t* data, size_t num_bytes); +typedef ssize_t (*microtvm_rpc_channel_write_t)(void* context, const uint8_t* data, + size_t num_bytes); /*! \brief Opaque pointer type to TVM RPC Server. */ -typedef void* utvm_rpc_server_t; +typedef void* microtvm_rpc_server_t; /*! \brief Initialize the TVM RPC Server. * - * Call this on device startup before calling anyother utvm_rpc_server_ functions. + * Call this on device startup before calling anyother microtvm_rpc_server_ functions. * * \param write_func A callback function invoked by the TVM RPC Server to write data back to the * host. 
Internally, the TVM RPC Server will block until all data in a reply @@ -56,7 +57,8 @@ typedef void* utvm_rpc_server_t; * \return A pointer to the TVM RPC Server. The pointer is allocated in the same memory space as * the TVM workspace. */ -utvm_rpc_server_t UTvmRpcServerInit(utvm_rpc_channel_write_t write_func, void* write_func_ctx); +microtvm_rpc_server_t MicroTVMRpcServerInit(microtvm_rpc_channel_write_t write_func, + void* write_func_ctx); /*! \brief Do any tasks suitable for the main thread, and maybe process new incoming data. * @@ -67,11 +69,11 @@ utvm_rpc_server_t UTvmRpcServerInit(utvm_rpc_channel_write_t write_func, void* w * updated to the number of unprocessed bytes remaining in `new_data` (usually 0). * \return An error code indicating the outcome of the server main loop iteration. */ -tvm_crt_error_t UTvmRpcServerLoop(utvm_rpc_server_t server, uint8_t** new_data, - size_t* new_data_size_bytes); +tvm_crt_error_t MicroTVMRpcServerLoop(microtvm_rpc_server_t server, uint8_t** new_data, + size_t* new_data_size_bytes); #ifdef __cplusplus } #endif -#endif // TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ +#endif // TVM_RUNTIME_CRT_MICROTVM_RPC_SERVER_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index b4fdcbff58b4..3b767547357b 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -389,4 +389,19 @@ inline DLDataType String2DLDataType(std::string s) { using DataType = runtime::DataType; } // namespace tvm + +namespace std { +template <> +struct hash { + inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; } + std::size_t operator()(tvm::DataType const& dtype) const { + int a = dtype.code(); + int b = dtype.bits(); + int c = dtype.lanes(); + int d = cantor_pairing_function(a, b); + return cantor_pairing_function(c, d); + } +}; +} // namespace std + #endif // TVM_RUNTIME_DATA_TYPE_H_ diff --git a/include/tvm/runtime/micro/standalone/utvm_runtime.h b/include/tvm/runtime/micro/standalone/microtvm_runtime.h similarity index 54% rename from include/tvm/runtime/micro/standalone/utvm_runtime.h rename to include/tvm/runtime/micro/standalone/microtvm_runtime.h index ef6cd4023dba..827d91f62076 100644 --- a/include/tvm/runtime/micro/standalone/utvm_runtime.h +++ b/include/tvm/runtime/micro/standalone/microtvm_runtime.h @@ -17,28 +17,29 @@ * under the License. 
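For reference, the `std::hash` specialization for `tvm::DataType` added above combines the type's code, bits and lanes with two applications of the Cantor pairing function; a rough Python rendering (illustration only, not part of the patch) is:

.. code-block:: python

    def cantor_pairing(a: int, b: int) -> int:
        # Bijective mapping from a pair of non-negative integers to a single integer.
        return (a + b) * (a + b + 1) // 2 + b

    def datatype_hash(code: int, bits: int, lanes: int) -> int:
        # Mirrors the C++ specialization: pair (code, bits) first, then pair the result with lanes.
        return cantor_pairing(lanes, cantor_pairing(code, bits))

    # e.g. float32 is (code=2, bits=32, lanes=1)
    assert datatype_hash(2, 32, 1) != datatype_hash(2, 16, 1)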
*/ -#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_ -#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_ +#ifndef TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_H_ +#define TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_H_ #include #include #define TVM_MICRO_RUNTIME_API_API extern "C" __attribute__((visibility("default"))) -TVM_MICRO_RUNTIME_API_API void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module); +TVM_MICRO_RUNTIME_API_API void* MicroTVMRuntimeCreate(const char* json, size_t json_len, + void* module); -TVM_MICRO_RUNTIME_API_API void UTVMRuntimeDestroy(void* handle); +TVM_MICRO_RUNTIME_API_API void MicroTVMRuntimeDestroy(void* handle); -TVM_MICRO_RUNTIME_API_API void UTVMRuntimeSetInput(void* handle, int index, void* tensor); +TVM_MICRO_RUNTIME_API_API void MicroTVMRuntimeSetInput(void* handle, int index, void* tensor); -TVM_MICRO_RUNTIME_API_API void UTVMRuntimeRun(void* handle); +TVM_MICRO_RUNTIME_API_API void MicroTVMRuntimeRun(void* handle); -TVM_MICRO_RUNTIME_API_API void UTVMRuntimeGetOutput(void* handle, int index, void* tensor); +TVM_MICRO_RUNTIME_API_API void MicroTVMRuntimeGetOutput(void* handle, int index, void* tensor); -TVM_MICRO_RUNTIME_API_API void* UTVMRuntimeDSOModuleCreate(const char* so, size_t so_len); +TVM_MICRO_RUNTIME_API_API void* MicroTVMRuntimeDSOModuleCreate(const char* so, size_t so_len); -TVM_MICRO_RUNTIME_API_API void UTVMRuntimeDSOModuleDestroy(void* module); +TVM_MICRO_RUNTIME_API_API void MicroTVMRuntimeDSOModuleDestroy(void* module); #undef TVM_MICRO_RUNTIME_API_API -#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_H_ +#endif // TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_H_ diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 32e74f6ef9d5..0ba7421ce409 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); -/*! - * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. - * - * \param stmt The stmt to be trasnformed. - * \param schedule The original schedule. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \return Transformed stmt. - */ -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer); - /*! * \brief Postprocessing the Stmt generated by ScheduleOps to create * a PrimFunc that can then be used for further TIR optimizations. diff --git a/include/tvm/topi/cuda/reduction.h b/include/tvm/topi/cuda/reduction.h index 7160419422a6..51f35ed8dc25 100644 --- a/include/tvm/topi/cuda/reduction.h +++ b/include/tvm/topi/cuda/reduction.h @@ -70,7 +70,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, if (out_stage->op.as()->axis.size() > 0) { all_reduce = false; num_thread = 32; - if (target->kind->name == "opencl") { + if (target->kind->name == "opencl" || target->kind->name == "metal") { // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests. // Don't know why. 
num_thread = 16; diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 00dd4643190e..dc338a3fcd3b 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -201,7 +201,10 @@ ("coremltools", None), ("cpplint", None), ("decorator", None), - ("docutils", None), + ( + "docutils", + "<0.17", + ), # Work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 ("future", None), ("image", None), ("matplotlib", None), diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 099502d17d78..0d18bc08e5ed 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -46,7 +46,7 @@ logger = logging.getLogger("auto_scheduler") -def call_all_topi_funcs(mod, params, target): +def call_all_topi_funcs(mod, params, target, opt_level=3): """Call all TOPI compute to extract auto_scheduler tasks in a Relay program""" # pylint: disable=import-outside-toplevel from tvm import relay @@ -57,7 +57,7 @@ def call_all_topi_funcs(mod, params, target): autotvm.GLOBAL_SCOPE.silent = True with transform.PassContext( - opt_level=3, + opt_level=opt_level, config={ "relay.backend.use_auto_scheduler": True, "relay.backend.disable_compile_engine_cache": True, @@ -91,7 +91,13 @@ def call_all_topi_funcs(mod, params, target): def extract_tasks( - mod, params, target, target_host=None, hardware_params=None, include_simple_tasks=False + mod, + params, + target, + target_host=None, + hardware_params=None, + include_simple_tasks=False, + opt_level=3, ): """Extract tuning tasks from a relay program. @@ -109,6 +115,8 @@ def extract_tasks( Hardware parameters used for the search tasks include_simple_tasks: bool Whether to extract simple tasks that do not include complicated ops. + opt_level : Optional[int] + The optimization level of the task extractions. Returns ------- @@ -132,7 +140,9 @@ def extract_tasks( with env: # Wrap build call in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool - build_thread = threading.Thread(target=call_all_topi_funcs, args=(mod, params, target)) + build_thread = threading.Thread( + target=call_all_topi_funcs, args=(mod, params, target, opt_level) + ) build_thread.start() build_thread.join() dispatch_ctx.verbose = old_verbose diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index dff0f098d84a..8d2591dce50b 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True): """Do lower while keeping all axes in IR i.e. 
Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads """ - binds, _ = build_module.get_binds(args, binds) + binds, _ = build_module.get_binds(args, compact=False, binds=binds) sch = sch.normalize() # Phase 0 bounds = schedule.InferBound(sch) diff --git a/python/tvm/autotvm/task/code_hash.py b/python/tvm/autotvm/task/code_hash.py index 3331fc13c719..2bd053da7244 100644 --- a/python/tvm/autotvm/task/code_hash.py +++ b/python/tvm/autotvm/task/code_hash.py @@ -19,6 +19,7 @@ code hashing is used to check the consistence of schedule code and the parameters loaded from log """ +import functools import inspect import zlib @@ -35,6 +36,7 @@ def attach_code_hash(s): """ def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) raw_hash = zlib.crc32("".join(inspect.getsourcelines(func)[0]).encode()) @@ -56,6 +58,7 @@ def attach_code_hash_to_arg(arg_idx=1): """ def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): func(*args, **kwargs) assert isinstance(args[arg_idx], schedule.Schedule) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 668832b8a86c..1f5827d7e9d0 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -21,6 +21,8 @@ func is a state-less function, or a string that registers the standard task. """ +import functools + import numpy as np from tvm import runtime @@ -411,6 +413,7 @@ def matmul(N, L, M, dtype): """ def _decorate(f): + @functools.wraps(f) def wrapper(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" workload = args_to_workload(args, task_name) diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 2558c7669ac9..32d8674640ed 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -26,6 +26,8 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ +import functools + import tvm.te._ffi_api from tvm.target import Target from tvm.te import tensor @@ -149,6 +151,7 @@ def register_topi_compute(task_name, func=None): """ def _decorate(topi_compute): + @functools.wraps(topi_compute) @_register_task_compute(task_name) def wrapper(*args, **kwargs): """wrapper function for topi compute""" @@ -224,6 +227,7 @@ def register_topi_schedule(task_name, func=None): """ def _decorate(topi_schedule): + @functools.wraps(topi_schedule) @_register_task_schedule(task_name) def wrapper(outs, *args, **kwargs): """wrapper function for topi schedule""" diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index 58bf933d44b8..1dfeb801b370 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -72,7 +72,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs C: Tensor The result tensor. """ - b = lhs.shape[0] + b = te.max(lhs.shape[0], rhs.shape[0]) n = lhs.shape[2] if transa else lhs.shape[1] m = rhs.shape[1] if transb else rhs.shape[2] return te.extern( diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 0e22e0c09274..0ca3c3d6d423 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -64,6 +64,24 @@ _ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"] +def exists(): + """ + Checks whether the local machine can use CuDNN. + + Returns + ------- + exists: bool + + True if CuDNN support is enabled and a CuDNN-capable GPU + exists. Otherwise, False. 
+ """ + func = tvm.get_global_func("tvm.contrib.cudnn.exists", allow_missing=True) + if func is None: + return False + + return bool(func()) + + def algo_to_index(algo_type, algo_name): """Return a index represents the algorithm, which can be used in calling CuDNN function @@ -209,6 +227,101 @@ def conv_output_shape( oshape: list output shape """ + + assert len(x_shape) == len(w_shape) + assert len(x_shape) in (4, 5) + + if tensor_format == 0: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[1] + w_chan_input = w_shape[1] + x_shape = x_shape[2:] + w_shape = w_shape[2:] + + elif tensor_format == 1: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[-1] + w_chan_input = w_shape[-1] + assert len(x_shape) == 4, "CuDNN layout NHWC is only well-defined for 4d tensors" + x_shape = x_shape[1:-1] + w_shape = w_shape[1:-1] + + elif tensor_format == 2: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[1] + w_chan_input = w_shape[1] + w_lanes = tvm.runtime.DataType(conv_dtype).lanes + assert w_lanes == 1 + x_shape = x_shape[2:] + w_shape = w_shape[2:] + + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + x_lanes = tvm.runtime.DataType(data_dtype).lanes + assert x_chan * x_lanes == w_chan_input * groups, ( + "Mismatched dimensions, data has {} channels/group " + "(dimension {} with {} lanes/value, {} groups), " + "but weights require {} input channels/group" + ).format(x_chan // groups, x_chan, x_lanes, groups, w_chan_input) + + output_dims = [] + for x_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip( + x_shape, w_shape, pad, stride, dilation + ): + output_dim = 1 + (x_shape_i + 2 * pad_i - (((w_shape_i - 1) * dilation_i) + 1)) // stride_i + output_dims.append(output_dim) + + if tensor_format in [0, 2]: + output = [n_output, c_output, *output_dims] + elif tensor_format == 1: + output = [n_output, *output_dims, c_output] + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + return output + + +def _conv_output_shape_from_cudnn( + tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1 +): + """Get output shape of 2D or 3D convolution. The output of this + function should be identical to that of conv_output_shape, but + requires a GPU with CuDNN to be present. This is maintained for + testing purposes to validate the output of conv_output_shape. + + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + x_shape: list + input shape + w_shape: list + weight shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + oshape: list + output shape + + """ dims = len(x_shape) assert dims in (4, 5) @@ -217,7 +330,7 @@ def conv_output_shape( ) oshape = np.zeros((dims), dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape") + func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn") func( tensor_format, dims - 2, diff --git a/python/tvm/contrib/mkl.py b/python/tvm/contrib/mkl.py index c6e340619ef8..449d660c9027 100644 --- a/python/tvm/contrib/mkl.py +++ b/python/tvm/contrib/mkl.py @@ -105,7 +105,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs C: Tensor The result tensor. 
""" - b = lhs.shape[0] + b = te.max(lhs.shape[0], rhs.shape[0]) n = lhs.shape[2] if transa else lhs.shape[1] m = rhs.shape[1] if transb else rhs.shape[2] return te.extern( diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py new file mode 100644 index 000000000000..c423656d78f5 --- /dev/null +++ b/python/tvm/driver/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.driver""" +import tvm._ffi + +tvm._ffi._init_api("driver", __name__) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a3d0bb656736..a4df63f225b2 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -37,96 +37,58 @@ from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from . import _ffi_api as ffi + def get_binds(args, compact=False, binds=None): """Internal function to get binds and arg_list given arguments. - Parameters ---------- args : list of Buffer or Tensor or Var The argument lists to the function. - compact : bool If the statement has already bound to a compact buffer. - binds : dict of :any:`Tensor` to :any:`Buffer`, optional Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. - Returns ------- binds: dict The bind specification - arg_list: list The list of symbolic buffers of arguments. """ - binds = {} if binds is None else binds.copy() - arg_list = [] - for x in args: - if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape) - buffer_type = "auto_broadcast" if any_dim and not compact else "" - if x not in binds: - buf = tvm.tir.decl_buffer( - x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type - ) - binds[x] = buf - arg_list.append(buf) - else: - arg_list.append(binds[x]) - elif isinstance(x, schedule.Buffer): - arg_list.append(x) - elif isinstance(x, tvm.tir.Var): - arg_list.append(x) - else: - raise ValueError("args must be Tensor, Buffer or Var") + binds, arg_list = ffi.get_binds(args, compact, binds) return binds, arg_list -def form_irmodule(sch, args, name, binds): +def schedule_to_module( + sch: schedule.Schedule, + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: """According to the given schedule, form a function. - Parameters ---------- sch : tvm.te.schedule.Schedule The given scheduler to form the raw body - args : list of Buffer or Tensor or Var The argument lists to the function. - name : str - The name of result function. 
- + The name of result function, default name is "main" binds : dict of :any:`Tensor` to :any:`Buffer`, optional The binds information - Returns ------- The body formed according to the given schedule """ - # normalize schedule first - pass_ctx = PassContext.current() - sch = sch.normalize() - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - - compact = schedule.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) - - stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) - func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) - - func = func.with_attr("global_symbol", name) - - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - return tvm.IRModule({name: func}) + return ffi.schedule_to_module(sch, args, name, binds) def lower( - inputs: Union[schedule.Schedule, PrimFunc, IRModule], + inp: Union[schedule.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -136,7 +98,7 @@ def lower( Parameters ---------- - input : Union[schedule.Schedule, PrimFunc, IRModule] + inputs : Union[schedule.Schedule, PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built args : Optional[List[Union[Buffer, tensor.Tensor, Var]]] @@ -160,90 +122,13 @@ def lower( m : IRModule The result IRModule """ - # config setup - pass_ctx = PassContext.current() - instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) - disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) - add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) - - lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] - lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] - lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] - lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - - # Phase 0 - pass_list = lower_phase0 - is_legacy_te_schedule: bool = False - - if isinstance(inputs, schedule.Schedule): - if args is None: - raise ValueError("args must be given for lowering from TE schedule") - mod = form_irmodule(inputs, args, name, binds) - is_legacy_te_schedule = True - elif isinstance(inputs, PrimFunc): - func = inputs.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) - elif isinstance(inputs, IRModule): - mod = inputs - else: - raise TypeError( - f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}" - ) - - # Phase 1 - if is_legacy_te_schedule: - pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), - ] - else: - pass_list += [ - tvm.tir.transform.LowerInitBlock(), - tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), - tvm.tir.transform.ConvertBlocksToOpaque(), - tvm.tir.transform.CompactBufferAllocation(), - tvm.tir.transform.FlattenBuffer(), - ] - pass_list += [ - tvm.tir.transform.BF16Legalize(), - tvm.tir.transform.NarrowDataType(32), - tvm.tir.transform.Simplify(), - ] - - pass_list += lower_phase1 - - # Phase 2 - if not simple_mode: - pass_list += [(tvm.tir.transform.LoopPartition())] - - pass_list += [ - tvm.tir.transform.VectorizeLoop(not disable_vectorize), - tvm.tir.transform.InjectVirtualThread(), - tvm.tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.StorageRewrite(), - 
tvm.tir.transform.UnrollLoop(), - ] - pass_list += lower_phase2 - - # Phase 3 - pass_list += [ - tvm.tir.transform.Simplify(), - tvm.tir.transform.RemoveNoOp(), - ] - - pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - pass_list += [tvm.tir.transform.HoistIfThenElse()] - pass_list += lower_phase3 - - # Instrument BoundCheckers - if instrument_bound_checkers: - pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - - optimize = tvm.transform.Sequential(pass_list) - mod = optimize(mod) - return mod + if isinstance(inp, IRModule): + return ffi.lower_module(inp, simple_mode) + if isinstance(inp, PrimFunc): + return ffi.lower_primfunc(inp, name, simple_mode) + if isinstance(inp, schedule.Schedule): + return ffi.lower_schedule(inp, args, name, binds, simple_mode) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def _build_for_device(input_mod, target, target_host): diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 48e18fb6b6ad..033522d0e81a 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -415,3 +415,103 @@ def parse_shape_string(inputs_string): shape_dict[name] = shape return shape_dict + + +def get_pass_config_value(name, value, config_type): + """Get a PassContext configuration value, based on its config data type. + + Parameters + ---------- + name: str + config identifier name. + value: str + value assigned to the config, provided via command line. + config_type: str + data type defined to the config, as string. + + Returns + ------- + parsed_value: bool, int or str + a representation of the input value, converted to the type + specified by config_type. + """ + + if config_type == "IntImm": + # "Bool" configurations in the PassContext are recognized as + # IntImm, so deal with this case here + mapping_values = { + "false": False, + "true": True, + } + + if value.isdigit(): + parsed_value = int(value) + else: + # if not an int, accept only values on the mapping table, case insensitive + parsed_value = mapping_values.get(value.lower(), None) + + if parsed_value is None: + raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ") + + if config_type == "runtime.String": + parsed_value = value + + return parsed_value + + +def parse_configs(input_configs): + """Parse configuration values set via command line. + + Parameters + ---------- + input_configs: list of str + list of configurations provided via command line. + + Returns + ------- + pass_context_configs: dict + a dict containing key-value configs to be used in the PassContext. + """ + if not input_configs: + return {} + + all_configs = tvm.ir.transform.PassContext.list_configs() + supported_config_types = ("IntImm", "runtime.String") + supported_configs = [ + name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types + ] + + pass_context_configs = {} + + for config in input_configs: + if not config: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + # Each config is expected to be provided as "name=value" + try: + name, value = config.split("=") + name = name.strip() + value = value.strip() + except ValueError: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + if name not in all_configs: + raise TVMCException( + f"Configuration '{name}' is not defined in TVM. 
" + f"These are the existing configurations: {', '.join(all_configs)}" + ) + + if name not in supported_configs: + raise TVMCException( + f"Configuration '{name}' uses a data type not supported by TVMC. " + f"The following configurations are supported: {', '.join(supported_configs)}" + ) + + parsed_value = get_pass_config_value(name, value, all_configs[name]["type"]) + pass_context_configs[name] = parsed_value + + return pass_context_configs diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 071474a31594..e79a07f1de2e 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -81,7 +81,15 @@ def add_compile_parser(subparsers): choices=["so", "mlf"], default="so", help="output format. Use 'so' for shared object or 'mlf' for Model Library Format " - "(only for µTVM targets). Defaults to 'so'.", + "(only for microTVM targets). Defaults to 'so'.", + ) + parser.add_argument( + "--pass-config", + action="append", + metavar=("name=value"), + help="configurations to be used at compile time. This option can be provided multiple " + "times, each one to set one configuration value, " + "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", ) parser.add_argument( "--target", @@ -145,6 +153,7 @@ def drive_compile(args): target_host=None, desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, + pass_context_configs=args.pass_config, ) return 0 @@ -162,6 +171,7 @@ def compile_model( target_host: Optional[str] = None, desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, + pass_context_configs: Optional[str] = None, ): """Compile a model from a supported framework into a TVM module. @@ -202,6 +212,9 @@ def compile_model( disabled_pass: str, optional Comma-separated list of passes which needs to be disabled during compilation + pass_context_configs: str, optional + String containing a set of configurations to be passed to the + PassContext. Returns @@ -212,7 +225,7 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params - config = {} + config = common.parse_configs(pass_context_configs) if desired_layout: mod = common.convert_graph_layout(mod, desired_layout) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 191f8616c405..fec8224ceb17 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -19,7 +19,7 @@ """ import json import logging -from typing import Optional, Dict, List, Union +from typing import Dict, List, Optional, Union import numpy as np import tvm @@ -30,12 +30,11 @@ from tvm.relay.param_dict import load_param_dict from . import common -from .model import TVMCPackage, TVMCResult from .common import TVMCException from .main import register_parser +from .model import TVMCPackage, TVMCResult from .result_utils import get_top_results - # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -51,7 +50,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl"], + choices=["cpu", "cuda", "cl", "metal"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -360,11 +359,11 @@ def run_module( ) # Currently only two package formats are supported: "classic" and - # "mlf". The later can only be used for micro targets, i.e. with µTVM. + # "mlf". The later can only be used for micro targets, i.e. with microTVM. 
if tvmc_package.type == "mlf": raise TVMCException( "You're trying to run a model saved using the Model Library Format (MLF)." - "MLF can only be used to run micro targets (µTVM)." + "MLF can only be used to run micro targets (microTVM)." ) if hostname: @@ -391,6 +390,8 @@ def run_module( dev = session.cuda() elif device == "cl": dev = session.cl() + elif device == "metal": + dev = session.metal() else: assert device == "cpu" dev = session.cpu() diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 910b0ce1721f..694aebe6f1ed 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -111,7 +111,7 @@ def get_runtime_libs(executor: str) -> str: source (i.e. not header) files. """ if executor == "host-driven": - crt_runtime_lib_names = ["utvm_rpc_server", "utvm_rpc_common", "common"] + crt_runtime_lib_names = ["microtvm_rpc_server", "microtvm_rpc_common", "common"] elif executor == "aot": crt_runtime_lib_names = ["aot_executor", "common"] else: diff --git a/python/tvm/micro/contrib/zephyr.py b/python/tvm/micro/contrib/zephyr.py index b7d7496b7440..3c79c200d155 100644 --- a/python/tvm/micro/contrib/zephyr.py +++ b/python/tvm/micro/contrib/zephyr.py @@ -30,6 +30,9 @@ import shutil import subprocess import sys +import threading +import queue +import enum import yaml @@ -111,9 +114,9 @@ def __init__( self._qemu = "qemu" in board # For Zephyr boards that run emulated by default but don't have the prefix "qemu_" in their - # board names, a suffix "-qemu" is added by users of µTVM when specifying the board name to - # inform that the QEMU transporter must be used just like for the boards with the prefix. - # Zephyr does not recognize the suffix, so we trim it off before passing it. + # board names, a suffix "-qemu" is added by users of microTVM when specifying the board + # name to inform that the QEMU transporter must be used just like for the boards with + # the prefix. Zephyr does not recognize the suffix, so we trim it off before passing it. if "-qemu" in board: board = board.replace("-qemu", "") @@ -172,6 +175,17 @@ def library(self, output, sources, options=None): project_dir_conf = os.path.join(self._project_dir, "prj.conf") if os.path.exists(project_dir_conf): shutil.copy(project_dir_conf, lib_prj_conf) + + # Copy board-specific Zephyr config file from the project_dir to + # the build lib dir so board-specific configs can be found and used by + # Zephyr's build system in conjunction with the generic prj.conf configs. 
+ board_conf = os.path.join("boards", self._board + ".conf") + project_dir_board_conf = os.path.join(self._project_dir, board_conf) + if os.path.exists(project_dir_board_conf): + os.mkdir(os.path.join(output, "boards")) + lib_dir_board_conf = os.path.join(output, board_conf) + shutil.copy(project_dir_board_conf, lib_dir_board_conf) + else: with open(lib_prj_conf, "w") as prj_conf_f: prj_conf_f.write("CONFIG_CPLUSPLUS=y\n") @@ -619,6 +633,12 @@ def write(self, data, timeout_sec): return num_written +class ZephyrQemuMakeResult(enum.Enum): + QEMU_STARTED = "qemu_started" + MAKE_FAILED = "make_failed" + EOF = "eof" + + class ZephyrQemuTransport(Transport): """The user-facing Zephyr QEMU transport class.""" @@ -630,6 +650,7 @@ def __init__(self, base_dir, startup_timeout_sec=5.0, qemu_debugger=None, **kwar self.fd_transport = None self.pipe_dir = None self.qemu_debugger = qemu_debugger + self._queue = queue.Queue() def timeouts(self): return TransportTimeouts( @@ -658,7 +679,12 @@ def open(self): ["make", "run", f"QEMU_PIPE={self.pipe}"], cwd=self.base_dir, **self.kwargs, + stdout=subprocess.PIPE, ) + try: + self._wait_for_qemu() + except Exception as error: + raise error if self.qemu_debugger is not None: self.qemu_debugger.start() @@ -703,6 +729,35 @@ def write(self, data, timeout_sec): raise TransportClosedError() return self.fd_transport.write(data, timeout_sec) + def _qemu_check_stdout(self): + for line in self.proc.stdout: + line = str(line) + _LOG.debug(line) + if "[QEMU] CPU" in line: + self._queue.put(ZephyrQemuMakeResult.QEMU_STARTED) + else: + line = re.sub("[^a-zA-Z0-9 \n]", "", line) + pattern = r"recipe for target (\w*) failed" + if re.search(pattern, line, re.IGNORECASE): + self._queue.put(ZephyrQemuMakeResult.MAKE_FAILED) + self._queue.put(ZephyrQemuMakeResult.EOF) + + def _wait_for_qemu(self): + threading.Thread(target=self._qemu_check_stdout, daemon=True).start() + while True: + try: + item = self._queue.get(timeout=120) + except Exception: + raise TimeoutError("QEMU setup timeout.") + + if item == ZephyrQemuMakeResult.QEMU_STARTED: + break + + if item in [ZephyrQemuMakeResult.MAKE_FAILED, ZephyrQemuMakeResult.EOF]: + raise RuntimeError("QEMU setup failed.") + + raise ValueError(f"{item} not expected.") + class ZephyrDebugger(debugger.GdbDebugger): """A Zephyr debugger implementation.""" diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 9460e23a5357..7378ed6beb8a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,45 +20,6 @@ from tvm.target import Target -@tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): - """Backend function for lowering. - - Parameters - ---------- - sch : tvm.te.Schedule - The schedule. - - inputs : List[tvm.te.Tensor] - The inputs to the function. - - func_name : str - The name of the function. - - source-func : tvm.relay.Function - The source function to be lowered. - - Returns - ------- - mod : tvm.IRModule - The result of lowering. 
- """ - # pylint: disable=broad-except, import-outside-toplevel - import traceback - - try: - f = tvm.driver.lower(sch, inputs, name=func_name) - # logging.debug("lower function %s", func_name) - # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) - except Exception: - msg = traceback.format_exc() - msg += "Error during compile function\n" - msg += "-----------------------------\n" - msg += source_func.astext() - raise RuntimeError(msg) - return f - - @tvm._ffi.register_func("relay.backend.build") def build(mod, target, target_host=None): """Backend build function. diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index b368f4e5175e..320a599d5d91 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -547,11 +547,11 @@ class CallPattern(DFPattern): Parameters ---------- - op: realy.dataflow_pattern.DFPattern + op: relay.dataflow_pattern.DFPattern The operation to be called. - args: List[realy.dataflow_pattern.DFPattern] - The arguments to the call. + args: List[relay.dataflow_pattern.DFPattern] + The arguments to the call or None to match any arguments. """ @@ -569,10 +569,10 @@ class FunctionPattern(DFPattern): Parameters ---------- - params: List[realy.dataflow_pattern.DFPattern] - The parameters to the Function. + params: List[relay.dataflow_pattern.DFPattern] + The parameters to the Function or None to match any parameters. - body: realy.dataflow_pattern.DFPattern + body: relay.dataflow_pattern.DFPattern The body fo the Function """ @@ -886,7 +886,7 @@ def partition( Parameters ---------- - partion: tvm.relay.dataflow_pattern.DFPattern + pattern: tvm.relay.dataflow_pattern.DFPattern The pattern to match expr : tvm.relay.Expr The expression to split into functions diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index acc33d73e826..c7d25d09859d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -391,10 +391,10 @@ def slice(self, inputs, input_types): stride = inputs[4] target_begin, is_begin_const = try_infer_value( - inputs[2], lambda ret: np.asscalar(ret.astype(np.int)) + inputs[2], lambda ret: ret.astype(np.int).item(0) ) target_end, is_end_const = try_infer_value( - inputs[3], lambda ret: np.asscalar(ret.astype(np.int)) + inputs[3], lambda ret: ret.astype(np.int).item(0) ) # A fast path when slicing is nop. 
@@ -1306,7 +1306,7 @@ def view(self, inputs, input_types): for i, shape in enumerate(shape_inp): if isinstance(shape, _expr.Expr): val = _infer_value_simulated(shape, {}) - new_shape[i] = np.asscalar(val.numpy()) + new_shape[i] = val.numpy().item(0) return _op.transform.reshape(data, new_shape) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 4a3db910c8a2..f614982aac6c 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -460,7 +460,7 @@ def _get_numpy(relay_const_scalar): def _get_scalar(relay_const_scalar): - return np.asscalar(_get_numpy(relay_const_scalar)) + return _get_numpy(relay_const_scalar).item(0) def _do_bias_and_requantize( diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index c9b37881e208..612ea908ced2 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1146,22 +1146,23 @@ def _impl(inputs, attr, params, mod): orig_shape_x = _infer_shape(input_x, mod) orig_shape_y = _infer_shape(input_y, mod) ndim = len(orig_shape_x) + ndim_y = len(orig_shape_y) is_static = not check_symbolic_shape(orig_shape_x) - if ndim > 3 and not is_static: - shape_of_x = list_shape_of(inputs[0], ndim) - shape_of_y = list_shape_of(inputs[1], ndim) - # reshape n-dimensional batch matmul into 3d if ndim > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] if is_static: num_outer_elts = np.prod(outer_dims) new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) - new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + if ndim_y > 2: + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + elif ndim_y == 2: + new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1]) else: # handle dynamic shape (dyn.reshape op) - # new shape = [prod(shape[:-2]), -2, -1] + shape_of_x = list_shape_of(inputs[0], ndim) + shape_of_y = list_shape_of(inputs[1], ndim) new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]] new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]] for i in range(ndim - 2): @@ -1172,7 +1173,8 @@ def _impl(inputs, attr, params, mod): input_x = _op.reshape(input_x, newshape=new_shape_x) input_y = _op.reshape(input_y, newshape=new_shape_y) - + elif ndim_y == 2: + input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"] input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x @@ -2898,6 +2900,7 @@ def _impl(inputs, attr, params, mod): "GreaterEqual": _broadcast("greater_equal"), "Identity": _identity(), "IdentityN": _identityn(), + "InvertPermutation": AttrCvt("invert_permutation"), "IsFinite": AttrCvt("isfinite"), "IsInf": AttrCvt("isinf"), "IsNan": AttrCvt("isnan"), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2d7613d046af..7e2173943265 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3496,7 +3496,7 @@ def get_scalar_from_constant(expr): assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype( np.float32 ), "value must be float32/int32" - return np.asscalar(value) + return value.item(0) def get_tensor_from_constant(expr): diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 4c693fe64ee0..2e509a111c4a 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -29,6 +29,7 @@ debug, 
register_external_compiler, register_fake_quantization_to_integer, + register_mixed_precision_conversion, ) from . import strategy diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index f87b5ed0b8ef..bee188f19364 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -178,6 +178,10 @@ def compute_unique(attrs, inputs, output_type): _reg.register_strategy("unique", strategy.unique_strategy) +# invert_permutation +_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy) +_reg.register_shape_func("invert_permutation", False, elemwise_shape_func) + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 27e581b942af..9f3c1cdec0f7 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -397,6 +397,14 @@ def qnn_dense(expr): return True +def check_dilation(attrs): + """Prevents offloading if dilation other than (1, 1)""" + if not isinstance(attrs, relay.op.op_attrs.GlobalPool2DAttrs): + if not (len(attrs.dilation) == 2 and attrs.dilation[0] == 1 and attrs.dilation[1] == 1): + return False + return True + + @tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib") def max_pool2d(expr): """Check if the external ACL codegen for maxpool2d should be used.""" @@ -406,7 +414,7 @@ def max_pool2d(expr): typ = args[0].checked_type if typ.dtype not in ["float32", "uint8"]: return False - return True + return check_dilation(attrs) @tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib") @@ -424,7 +432,7 @@ def avg_pool2d(expr, from_quantized_composite=False): if attrs.layout != "NHWC": return False - return True + return check_dilation(attrs) @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib") diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 5b7fd32add4c..2071a43f828b 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,6 +26,7 @@ from .. import op as reg from .. import strategy from ..op import OpPattern +from .image import resize # resize @@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") +@reg.register_convert_op_layout("image.resize") +def convert_image_resize(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. 
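A hedged sketch (not part of the patch) of driving this new registration through `ConvertLayout`, using a tiny module assumed purely for illustration:

.. code-block:: python

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
    r = relay.image.resize(x, size=(64, 64), layout="NCHW")
    mod = tvm.IRModule.from_expr(r)

    # Request NHWC for image.resize through the ConvertLayout pass.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ConvertLayout({"image.resize": ["NHWC"]})(mod)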
+ + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize(*inputs, **new_attrs) + + @script def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index ccf011819a97..0d90a5cdeafa 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -18,10 +18,11 @@ """The base node types for the Relay language.""" import tvm._ffi import tvm.ir -from tvm.driver import lower, build -from tvm.target import get_native_generic_func, GenericFunc -from tvm.runtime import Object import tvm.ir._ffi_api +from tvm.driver import build, lower +from tvm.runtime import Object +from tvm.target import GenericFunc, get_native_generic_func + from . import _make @@ -457,6 +458,32 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level) +def register_mixed_precision_conversion(op_name, func=None, level=10): + """Register mixed precision conversion function for an op + + Given an op the function should return information on how the value should be + converted. Specifically the function should take a call node and the target + mixed precision datatype (e.g. FP16) and return the conversion category + (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation + and output datatype of the operation in the mixed precision dtype space. + + Parameters + ---------- + op_name : str + The name of the operator + + func: function (call_node: relay.Call, target_dtype: string) + -> [conversion category, accumulation dtype, output dtype]: [int, string, string] + A function which given a call_node and target_dtype (e.g. FP16) returns the + conversion category and associated accumulation/output of the operation + when transformed into the mixed precision dtype space. 
+ + level : int + The priority level + """ + return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level) + + @tvm._ffi.register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index ef0c355ffe47..2f49aa4a89c7 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1155,3 +1155,15 @@ def schedule_transpose_cuda(attrs, outs, target): ): return topi.cuda.schedule_transpose(outs) return schedule_injective(attrs, outs, target) + + +@invert_permutation_strategy.register(["cuda", "gpu"]) +def invert_permutation_strategy_cuda(attrs, inputs, out_type, target): + """invert_permutation cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_invert_permutation(topi.cuda.invert_permutation), + wrap_topi_schedule(topi.cuda.vision._default_schedule), + name="invert_permutation.cuda", + ) + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index ed3bc4af8d3d..edb4556a554b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1643,3 +1643,25 @@ def schedule_transpose(attrs, outs, target): """schedule transpose""" with target: return schedule_injective(attrs, outs, target) + + +# invert_permutation +def wrap_compute_invert_permutation(topi_compute): + """wrap invert_permutation topi compute""" + + def _compute_invert_permutation(attrs, inputs, out_type): + return [topi_compute(inputs[0])] + + return _compute_invert_permutation + + +@override_native_generic_func("invert_permutation_strategy") +def invert_permutation_strategy(attrs, inputs, out_type, target): + """invert_permutation generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_invert_permutation(topi.invert_permutation), + wrap_topi_schedule(topi.generic.schedule_injective), + name="invert_permutation.generic", + ) + return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 049ddc9622ba..9cb50ed6548a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -48,12 +48,15 @@ def cast(data, dtype): def cast_like(data, dtype_like): """Cast input tensor to data type of another tensor. + Parameters ---------- data : relay.Expr The input data to the operator. + dtype_like: relay.Expr The tensor to cast to. + Returns ------- result : relay.Expr @@ -1717,3 +1720,31 @@ def unique(data, is_sorted=True, return_counts=False): if return_counts: return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5) return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) + + +def invert_permutation(data): + """Computes the inverse permutation of data. + This operation computes the inverse of an index permutation. + It takes a 1-D integer tensor x, which represents the indices of a zero-based + array and swaps each value with its index position. + + For an output tensor y and an input tensor x, this operation computes the following: + y[x[i]] = i for i in [0, 1, ..., len(x) - 1] + + Parameters + ---------- + data : relay.Expr + The source data to be inverse permuted. + + Returns + ------- + ret : relay.Expr + Inverse permuted data. Has the same type as data. + + Examples + -------- + ..
code-block:: python + data = [3, 4, 0, 2, 1] + relay.invert_permutation(data) = [2, 4, 3, 0, 1] + """ + return _make.invert_permutation(data) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3c4d2ddcd0ec..961517f863fb 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -91,7 +91,7 @@ def get_scalar_from_constant(expr): assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype( np.float32 ), "value must be float32/int32" - return np.asscalar(value) + return value.item(0) # Helper function for lowering in the abscence of fast Int8 arithmetic units. diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py new file mode 100644 index 000000000000..6aa3ac09cfee --- /dev/null +++ b/python/tvm/relay/transform/mixed_precision.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long,unused-argument +"""Default behavior for ops in mixed_precision pass. Import this file to use.""" +from typing import List + +from tvm import relay +from tvm.relay.op import register_mixed_precision_conversion + +# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory +# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to +# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to +# numerical reasons. +MIXED_PRECISION_ALWAYS = 0 +MIXED_PRECISION_FOLLOW = 1 +MIXED_PRECISION_NEVER = 2 + +# Default lists inspired from TF's classifications: +# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice. +DEFAULT_ALWAYS_LIST = [ + "nn.conv1d", + "nn.conv2d", + "nn.conv3d", + "nn.conv1d_transpose", + "nn.conv2d_transpose", + "nn.conv3d_transpose", + "nn.dense", + # "nn.batch_matmul", # Handled by a special case +] +DEFAULT_FOLLOW_LIST = [ + # These ops add new data or change shape + "nn.pad", + "nn.batch_flatten", + "concatenate", + "zeros", + "split", + "squeeze", + "transpose", + "expand_dims", + "reshape", + "dyn.reshape", + "broadcast_to_like", + "dyn.broadcast_to", + "strided_slice", + "dyn.strided_slice", + "take", + "argwhere", + "where", + "tile", + "dyn.tile", + "scatter", + "full", + "dyn.full", + # Comparison + "less", + "greater", + "less_equal", + "greater_equal", + # By definition copy and cast will depend on inputs for output. 
+ "copy", + "cast", + "cast_like", + # Simple arithmetic + "add", + "subtract", + "multiply", + "divide", + "nn.bias_add", + "nn.batch_norm", + "sum", + "mean", + "sqrt", + "shape_of", + # Simple activations + "max", + "min", + "maximum", + "minimum", + "nn.relu", + "nn.leaky_relu", + "nn.prelu", + "nn.dropout", + # Complicated activations which saturate in a narrow range + "sigmoid", + "tanh", + # Pooling operations + "nn.max_pool1d", + "nn.max_pool2d", + "nn.max_pool3d", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + # "nn.global_max_pool1d", # does not exist yet + "nn.global_max_pool2d", + # "nn.global_max_pool3d", # does not exist yet + # "nn.global_avg_pool1d", # does not exist yet + "nn.global_avg_pool2d", + # "nn.global_avg_pool3d", # does not exist yet + "nn.adaptive_max_pool1d", + "nn.adaptive_max_pool2d", + "nn.adaptive_max_pool3d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", +] +DEFAULT_NEVER_LIST = [ + # In general if |f(x)| >> |x| for expected inputs then put the op here. + "exp", + "power", + "nn.cross_entropy", + "nn.cross_entropy_with_logits", + "nn.softmax", + "nn.l2_normalize", + # Error function doesn't seem to be able to be lowered into fp16 version in llvm. + # Move to follow list when it does. + "erf", +] + + +# Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType +def register_func_to_op_list(list_ops: List): + def decorator(func): + for op_name in list_ops: + register_mixed_precision_conversion(op_name, func=func) + + return decorator + + +def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: + """A function which returns output dtypes in a way which works for most ops. + + Parameters + --------- + call_node: relay.Call + The call node containing the op. + mixed_precision_type: str + The target type to run the operation in. + Returns + ------- + output_dtypes : [str, str] + A list of two strings. The first represents the datatype used for accumulation + in the operation. The second represents the actual output datatype. + """ + # Assume support accumulation dtypes <---> has out_dtype attr. + # This is because there is no better way right now to tell which ops support accumulating + # at different data types. + # Some discussion here about making this better is here: + # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo + if hasattr(call_node.attrs, "out_dtype"): + return ["float32", mixed_precision_type] + + # [accumulation_dtype, output_dtype] for the operations + return [mixed_precision_type, mixed_precision_type] + + +# Functions for FTVMMixedPrecisionConversionType which +# Take in CallNodes and a DType and returns a conversion type, +# an accumulation dtype, and an output_dtype. 
+@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) +def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type) + + +@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) +def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type) + + +@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) +def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List: + return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type) + + +@register_mixed_precision_conversion("nn.batch_matmul") +def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List: + # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well. + # Batched matmul has inconsistent support for mixed precision operations. + # Many schedules ignore the out_dtype attribute which leads to errors when + # input types do not match the out_dtype. Therefore, accumulate to output_dtype. + return [MIXED_PRECISION_ALWAYS, "float16", "float16"] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e045abab6c..fa7f4c4db644 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -18,16 +18,15 @@ """ Relay pass transformation infrastructure. """ -import types -import inspect import functools +import inspect +import types import warnings import tvm.ir -from tvm import te +from tvm import relay, te from tvm.runtime import ndarray as _nd -from tvm import relay from . import _ffi_api @@ -1168,7 +1167,7 @@ def AnnotateSpans(): Returns ------- ret : tvm.transform.Pass - The regsistered AnnotateSpans pass. + The registered AnnotateSpans pass. """ return _ffi_api.AnnotateSpans() @@ -1199,3 +1198,29 @@ def FakeQuantizationToInteger(): The registered SimplifyExpr pass. """ return _ffi_api.FakeQuantizationToInteger() + + +def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): + """ + Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version + where as many operations as possible are in the target mixed_precision_type. + + Parameters + ---------- + mixed_precision_type: str + The target datatype to transform operations in the graph to use. + + missing_op_mode: int + Determines how to handle ops not registered with FTVMMixedPrecisionConversionType + 0: Does not allow any missing ops. Will throw errors when encountering any. + 1: Allow missing ops but emit warnings. + 2: Allow missing ops and silently ignore them. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass. + """ + if missing_op_mode < 0 or missing_op_mode > 2: + raise ValueError("Missing op mode is either 0, 1, or 2") + return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index c07e88b59e37..0b49b675d77d 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -143,7 +143,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): listen_sock: Socket The socket used by listening process. 
- tracker_conn : connnection to tracker + tracker_conn : connection to tracker Tracker connection ping_period : float, optional @@ -216,7 +216,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): if magic != base.RPC_TRACKER_MAGIC: raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr)) # report status of current queue - cinfo = {"key": "server:" + rpc_key} + cinfo = {"key": "server:" + rpc_key, "addr": (custom_addr, port)} base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 9506a52dd55d..74c1f7ac07aa 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -16,7 +16,7 @@ # under the License. """RPC Tracker, tracks and distributes the TVM RPC resources. -This folder implemements the tracker server logic. +This folder implements the tracker server logic. Note ---- @@ -67,7 +67,7 @@ class Scheduler(object): - """Abstratc interface of scheduler.""" + """Abstract interface of scheduler.""" def put(self, value): """Push a resource into the scheduler. @@ -167,7 +167,7 @@ def __init__(self, tracker, sock, addr): self._msg_size = 0 self._addr = addr self._init_req_nbytes = 4 - self._info = {"addr": addr} + self._info = {} # list of pending match keys that has not been used. self.pending_matchkeys = set() self._tracker._connections.add(self) @@ -272,7 +272,11 @@ def _cb(value): else: self.ret_value(TrackerCode.FAIL) elif code == TrackerCode.UPDATE_INFO: - self._info.update(args[1]) + info = args[1] + assert isinstance(info, dict) + if info["addr"][0] is None: + info["addr"][0] = self._addr[0] + self._info.update(info) self.ret_value(TrackerCode.SUCCESS) elif code == TrackerCode.SUMMARY: status = self._tracker.summary() diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 76ddbb1de697..38ff1b71f07d 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -185,6 +185,11 @@ def opaque_axis(begin, end, span): return get_axis(begin, end, "opaque", span) +@register +def Select(cond, if_body, else_body, span): # pylint: disable=invalid-name + return tvm.tir.Select(cond, if_body, else_body, span) + + @register class EvaluateIntrin(Intrin): def __init__(self): diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 393d0395aad6..49f71041590b 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -749,14 +749,17 @@ def f(): node.call.func_name.span, ) - if isinstance(func, Intrin) and func.stmt: - return call_with_error_reporting( - self.report_error, - node.call.func_name.span, - func.handle, - arg_list, - node.call.func_name.span, - ) + if isinstance(func, Intrin): + if func.stmt: + return call_with_error_reporting( + self.report_error, + node.call.func_name.span, + func.handle, + arg_list, + node.call.func_name.span, + ) + else: + self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span) elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: func.enter_scope(node, self.context, arg_list, node.call.func_name.span) func.body = self.parse_body(node) @@ -765,7 +768,11 @@ def f(): func.handle(node, self.context, arg_list, node.call.func_name.span) return - self.report_error(f"Invalid Expr stmt {type(func).__name__}.", node.call.func_name.span) + self.report_error( + "Unexpected statement. 
Expected an assert, an intrinsic, a with statement, or a " + f"special statement, but got {type(func).__name__}.", + node.call.func_name.span, + ) def transform_Slice(self, node): start = self.transform(node.start) @@ -785,7 +792,9 @@ def transform_Subscript(self, node): symbol = self.transform(node.params[0]) if symbol is None: - self.report_error(f"Variable {node.value.id} is not defined.", node.params[0].span) + self.report_error( + f"Variable {node.params[0].id.name} is not defined.", node.params[0].span + ) indexes = [self.transform(x) for x in node.params[1].values] if isinstance(symbol, tvm.tir.expr.Var): @@ -844,7 +853,7 @@ def transform_Attr(self, node): self.report_error("Unsupported Attribute expression.", node.object.span) if not hasattr(symbol, node.field.name): self.report_error( - f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span + f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span ) res = getattr(symbol, node.field.name) return res diff --git a/python/tvm/script/ty.py b/python/tvm/script/ty.py index 1d7871624eb5..960e090a163c 100644 --- a/python/tvm/script/ty.py +++ b/python/tvm/script/ty.py @@ -62,7 +62,14 @@ def __getitem__(self, vtypes): return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) +int8 = ConcreteType("int8") +int16 = ConcreteType("int16") int32 = ConcreteType("int32") +int64 = ConcreteType("int64") +float16 = ConcreteType("float16") +float32 = ConcreteType("float32") +float64 = ConcreteType("float64") +boolean = ConcreteType("bool") handle = ConcreteType("handle") Ptr = GenericPtrType() Tuple = GenericTupleType() diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index 63c7c9308284..a199534ccb51 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -117,9 +117,15 @@ def conv2d_cudnn( else: dtype = data.dtype - cfg.define_knob("algo", range(8)) - if cfg.is_fallback: # Let CUDNN choose the best algo - cfg["algo"] = OtherOptionEntity(-1) + cfg.define_knob("algo", range(cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT"))) + if cfg.is_fallback: + if cudnn.exists(): + # Let CUDNN choose the best algo, based on benchmarks run + # on the local machine. In the future, this should be + # based on parameters stored in the Target. 
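For context on the knob defined above, `cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT")` is the total number of cuDNN forward-convolution algorithms, so the tuner can now try each concrete algorithm index, while `-1` keeps the "let cuDNN benchmark and choose" behavior. A small hedged sketch of those helpers, assuming a TVM build where `tvm.contrib.cudnn` is importable (nothing below is taken from the diff):

```python
# Hedged sketch of the helpers used above; not taken from the diff.
from tvm.contrib import cudnn

if cudnn.exists():
    # Number of cuDNN forward-convolution algorithms: the "algo" knob enumerates
    # indices 0 .. count - 1, and the fallback keeps -1 so cuDNN benchmarks and
    # picks the algorithm itself.
    count = cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT")
    print("tunable cuDNN fwd algorithms:", list(range(count)))
else:
    # Without cuDNN present the fallback cannot defer to cuDNN, so it pins algorithm 0.
    print("cuDNN not available")
```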
+ cfg["algo"] = OtherOptionEntity(-1) + else: + cfg["algo"] = OtherOptionEntity(0) return cudnn.conv_forward( data, diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 067f27262b06..4863a06b728d 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -270,6 +270,60 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return None +def _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor): + # Pad batch size + if db != 0: + data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, db), (0, 0))) + + # Pad input channel + if di != 0: + data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di))) + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, di))) + + # Pad output channel + if do != 0: + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, do), (0, 0))) + + if do != 0: + new_out_channel = out_channel + do + new_attrs["channels"] = new_out_channel + + out = relay.nn.conv2d(data, kernel, **new_attrs) + + if db != 0 or do != 0: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) + + return out + + +def _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor): + # Pad batch size + if db != 0: + data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0))) + + # Pad input channel + if di != 0: + data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di))) + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0))) + + # Pad output channel + if do != 0: + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do))) + + if do != 0: + new_out_channel = out_channel + do + new_attrs["channels"] = new_out_channel + + out = relay.nn.conv2d(data, kernel, **new_attrs) + + if db != 0 or do != 0: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) + + return out + + @conv2d_legalize.register("cuda") def _conv2d_legalize(attrs, inputs, arg_types): """Legalizes Conv2D op. 
@@ -347,7 +401,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): else: out = relay.nn.conv2d(data, kernel, **new_attrs) return out - elif data_dtype in ["float16"]: # todo: support int8/int4 + if data_layout == "NHWC" and kernel_layout == "HWIO": batch = data_tensor.shape[0].value in_channel = data_tensor.shape[3].value @@ -361,7 +415,10 @@ def _conv2d_legalize(attrs, inputs, arg_types): # no need to pad return None - (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel) + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + (db, di, do), extra_flops = pad_to_tensorcore( + batch, in_channel, out_channel, candidates + ) if extra_flops > 2: logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) @@ -369,28 +426,100 @@ def _conv2d_legalize(attrs, inputs, arg_types): logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) - # Pad batch size - if db != 0: - data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0))) + return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) - # Pad input channel - if di != 0: - data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di))) - kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0))) + if data_layout == "HWNC" and kernel_layout == "HWOI": + batch = data_tensor.shape[2].value + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[2].value - # Pad output channel - if do != 0: - kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do))) + if batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0: + return None - if do != 0: - new_out_channel = out_channel + do - new_attrs["channels"] = new_out_channel + candidates = [(8, 16, 32)] + (db, di, do), extra_flops = pad_to_tensorcore( + batch, in_channel, out_channel, candidates + ) - out = relay.nn.conv2d(data, kernel, **new_attrs) + if extra_flops > 2: + logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) - if db != 0 or do != 0: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) + return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) + + elif data_dtype in ["float16"]: + if data_layout == "NHWC" and kernel_layout == "HWIO": + batch = data_tensor.shape[0].value + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[3].value + + if ( + (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0) + or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0) + or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0) + ): + # no need to pad + return None + + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + (db, di, do), extra_flops = pad_to_tensorcore( + batch, in_channel, out_channel, candidates + ) + + if extra_flops > 2: + logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) + + return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) + + elif data_dtype in ["int4", "uint4"]: + if data_layout == "NHWC" and kernel_layout == "HWIO": + batch = data_tensor.shape[0].value + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[3].value + + if ( + (batch % 8 == 0 and in_channel % 
16 == 0 and out_channel % 32 == 0) + or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0) + or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0) + ): + # no need to pad + return None + + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + (db, di, do), extra_flops = pad_to_tensorcore( + batch, in_channel, out_channel, candidates + ) + + if extra_flops > 2: + logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) + + return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) + + if data_layout == "HWNC" and kernel_layout == "HWOI": + batch = data_tensor.shape[2].value + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[2].value + + if batch % 8 == 0 and in_channel % 32 == 0 and out_channel % 8 == 0: + return None + + candidates = [(8, 32, 8)] + (db, di, do), extra_flops = pad_to_tensorcore( + batch, in_channel, out_channel, candidates + ) + + if extra_flops > 2: + logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) + + return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor) - return out return None diff --git a/python/tvm/topi/cuda/conv3d.py b/python/tvm/topi/cuda/conv3d.py index 530df31ed3dc..51f1f7a27c7e 100644 --- a/python/tvm/topi/cuda/conv3d.py +++ b/python/tvm/topi/cuda/conv3d.py @@ -221,6 +221,16 @@ def conv3d_cudnn( * ((KW - 1) * dilation_w + 1) ) + cfg.define_knob("algo", range(cudnn.algo_to_index("fwd", "CUDNN_CONVOLUTION_FWD_ALGO_COUNT"))) + if cfg.is_fallback: + if cudnn.exists(): + # Let CUDNN choose the best algo, based on benchmarks run + # on the local machine. In the future, this should be + # based on parameters stored in the Target. 
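Stepping back to the legalization branches above: `pad_to_tensorcore` now receives the candidate `(batch, in_channel, out_channel)` multiples explicitly, so each dtype/layout branch can limit which tensor-core shapes it is willing to pad toward, and the padding is only applied when the extra-FLOPs ratio stays at or below 2. A rough sketch of a call, with invented shapes; the exact return values depend on the candidate search in `tensorcore_alter_op.py`:

```python
# Rough sketch with made-up shapes; not taken from the diff.
from tvm.topi.cuda.tensorcore_alter_op import pad_to_tensorcore

batch, in_channel, out_channel = 30, 121, 100
# The int8/float16/int4 NHWC branches pass these three candidates; the int4 HWNC
# branch restricts itself to [(8, 32, 8)].
candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
(db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel, candidates)
# (db, di, do) is the padding needed on each dimension for the cheapest candidate;
# the legalization pads with relay.nn.pad and strided_slice's the output back to
# its original shape afterwards.
print(db, di, do, extra_flops)
```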
+ cfg["algo"] = OtherOptionEntity(-1) + else: + cfg["algo"] = OtherOptionEntity(0) + return cudnn.conv_forward( data, kernel, @@ -229,7 +239,7 @@ def conv3d_cudnn( [dilation_d, dilation_h, dilation_w], conv_mode=1, tensor_format=tensor_format, - algo=-1, # let CUDNN choose the best algo + algo=cfg["algo"].val, conv_dtype=dtype, ) diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index ceab71640533..b9d02d9c81d8 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -37,7 +37,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): all_reduce = False num_thread = 32 target = tvm.target.Target.current() - if target and target.kind.name == "opencl": + if target and (target.kind.name == "opencl" or target.kind.name == "metal"): # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why num_thread = 16 diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index aec7acbfde56..eb7c71ddf1c9 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -71,7 +71,8 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): # no need to pad return None - (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N) + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) if extra_flops > 2: logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) @@ -145,7 +146,8 @@ def _dense_legalize(attrs, inputs, arg_types): # no need to pad return None - (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N) + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) if extra_flops_ratio > 2: logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) @@ -171,10 +173,8 @@ def _dense_legalize(attrs, inputs, arg_types): return None -def pad_to_tensorcore(M, K, N): +def pad_to_tensorcore(M, K, N, candidates): """pad shape to enable tensorcore""" - candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - flops = M * K * N extra_flops = math.inf best_pad = (0, 0, 0) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index 89caf94bbbc1..16b1273def47 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """CUDA implementations of transforms""" - +import tvm from ... import te from ...target import Target from ..utils import traverse_inline @@ -65,3 +65,74 @@ def _callback(op): s[c].bind(ao, thread_y) traverse_inline(s, out.op, _callback) + + +def _invert_permutation_ir(data, out): + """Low level IR to get invert_permutation. + + Parameters + ---------- + data : Buffer + Input data. 1-D Buffer with shape [elem_num]. + + out : Buffer + 1D buffer for invert permutation result with the same shape with data. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + elem_num = data.shape[0] + + irb = tvm.tir.ir_builder.create() + data = irb.buffer_ptr(data) + out = irb.buffer_ptr(out) + + max_threads = int(Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = elem_num // max_threads + 1 + thread_x = te.thread_axis("threadIdx.x") + block_x = te.thread_axis("blockIdx.x") + irb.scope_attr(thread_x, "thread_extent", nthread_tx) + irb.scope_attr(block_x, "thread_extent", nthread_bx) + tid = block_x * max_threads + thread_x + + with irb.if_scope(tid < elem_num): + r_ind = data[tid] + out[r_ind] = tid + return irb.get() + + +def invert_permutation(data): + """Compute definition of invert_permutation. + For an output tensor y and an input tensor x, this operation computes the following: + + y[x[i]] = i for i in [0, 1, ..., len(x) - 1] + + Parameters + ---------- + data : tvm.te.Tensor + 1-D tensor + + Returns + ------- + out : tvm.te.Tensor + """ + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) + + out = te.extern( + [data.shape], + [data], + lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]), + in_buffers=[ + data_buf, + ], + out_buffers=[ + out_buf, + ], + name="invert_permutation", + tag="invert_permutation_gpu", + ) + return out diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 130eb4b69844..3f72bdc4b667 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -18,13 +18,15 @@ # pylint: disable=unused-argument, redefined-builtin """Conv2D operators""" from __future__ import absolute_import as _abs + from collections import namedtuple + import tvm -from tvm import te, auto_scheduler +from tvm import auto_scheduler, te +from ..utils import get_const_int, get_const_tuple, simplify, tag from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify, get_const_tuple, get_const_int, tag from .winograd_util import winograd_transform_matrices # workload description of conv2d @@ -548,7 +550,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ow * WSTR + kw * dilation_w, idxmod(ic, ic_bn), ].astype(out_dtype) - * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block], + * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype( + out_dtype + ), axis=[ic, kh, kw], ), name="conv2d_NCHWc", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b4d0167be2b1..45756eadbcdb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm import topi +from tvm.te import hybrid from . import cpp from . import tag from .utils import within_index, make_idx, const_vector @@ -941,3 +942,31 @@ def adv_index(data, indices): Output tensor """ return cpp.adv_index(data, indices) + + +@hybrid.script +def invert_permutation(data): + """Computes the inverse permutation of data. + + Parameters + ---------- + data : tvm.te.Tensor + Input data + + Returns + ------- + result : tvm.te.Tensor + Output tensor + + Examples + -------- + .. 
code-block:: python + data = [3, 4, 0, 2, 1] + topi.invert_permutation(data) = [2, 4, 3, 0, 1] + """ + result = output_tensor(data.shape, data.dtype) + nums = data.shape[0] + for ind in range(nums): + r_ind = data[ind] + result[r_ind] = ind + return result diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index df480123375d..37bdd09d6ca6 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -139,7 +139,7 @@ def _default_batch_matmul_config(cfg, M, N, K): def batch_matmul_blas_common(cfg, x, y, out_shape, lib): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. + data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension. Parameters ---------- @@ -162,10 +162,10 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" XB, M, XK = get_const_tuple(x.shape) YB, N, YK = get_const_tuple(y.shape) - assert XB == YB, "batch dimension doesn't match" + assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistent" if out_shape is not None: - assert out_shape[0] == XB, "got invalid output shape" + assert out_shape[0] in (XB, YB), "got invalid output shape" assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs index 2a342bcc3453..b02e3f69b671 100644 --- a/rust/tvm-macros/src/util.rs +++ b/rust/tvm-macros/src/util.rs @@ -43,6 +43,6 @@ pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { .. 
})) => s, Ok(m) => panic!("Expected a string literal, got {:?}", m), - Err(e) => panic!(e), + Err(e) => panic!("{}", e), } } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 264d5febd103..64fd6a2218aa 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -109,7 +109,7 @@ impl Object { let mut index = 0; unsafe { if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 { - panic!(crate::get_last_error()) + panic!("{}", crate::get_last_error()) } } return index; diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index c411c0cd31a7..22e51e4e7396 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -191,7 +191,7 @@ fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { let config = codespan_reporting::term::Config::default(); for diagnostic in diag_ctx.diagnostics.clone() { match source_map.source_map.get(&diagnostic.span.source_name) { - Err(err) => panic!(err), + Err(err) => panic!("{}", err), Ok(source) => { state.add_source(source); let diagnostic = state.to_diagnostic(diagnostic); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ff6f409db483..cd8173717d5f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -92,22 +92,62 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std offset_factor, buffer_type); } -void GetBinds(const Array& args, bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, Array* out_arg_list) { *out_binds = binds; - for (const auto& x : args) { - if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact); - out_binds->Set(x, buf); - out_arg_list->push_back(buf); + for (const ObjectRef& x : args) { + if (const te::TensorNode* tensor_node = x.as()) { + te::Tensor x_ref = GetRef(tensor_node); + if (out_binds->find(x_ref) == out_binds->end()) { + tir::Buffer buf = + BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); + out_binds->Set(x_ref, buf); + out_arg_list->push_back(buf); + } else { + out_arg_list->push_back((*out_binds)[x_ref]); + } + } else if (x.as() || x.as()) { + out_arg_list->push_back(x); } else { - out_arg_list->push_back((*out_binds)[x]); + LOG(FATAL) + << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " + << "but got a " << x->GetTypeKey(); } } } +void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + GetBinds(ref_args, compact, binds, out_binds, out_arg_list); +} + +TVM_REGISTER_GLOBAL("driver.get_binds") + .set_body_typed([](const Array& args, bool compact, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + Map out_binds; + Array out_arg_list; + GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); + + // TVM object system doesn't have a pair object, so we'll put both ret values in an array + // and return that. 
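The practical effect of generalizing `GetBinds` to `ObjectRef` arguments is easiest to see from the Python side: `te.Tensor` arguments are still converted to `tir.Buffer`s, while `tir.Buffer` and `tir.Var` entries in `args` are now forwarded through rather than rejected. A minimal sketch using the long-standing `tvm.lower` entry point that these functions back; the toy schedule and the name `add_one` are invented:

```python
# Minimal sketch; the compute definition and function name are invented.
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# The te.Tensor args below are bound to tir.Buffers by GetBinds during lowering;
# explicit tir.Buffer or tir.Var entries in this list are now accepted as well.
mod = tvm.lower(s, [A, B], name="add_one")
print(mod["add_one"])  # the lowered tir.PrimFunc
```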
+ Array out_arr = {out_binds, out_arg_list}; + return out_arr; + }); + transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -127,63 +167,208 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -IRModule lower(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { - Array out_arg_list; - auto pass_ctx = transform::PassContext::Current(); - - sch = sch.normalize(); - - // Before TIR transformation. - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - bool compact = te::VerifyCompactBuffer(stmt); - - Map out_binds; - GetBinds(args, compact, binds, &out_binds, &out_arg_list); - - // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); +Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { + transform::PassContext pass_ctx = transform::PassContext::Current(); - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + // Get any user-added passes + Array> add_lower_pass = + pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) + .value(); + + Array user_lower_phase0 = Array(); + Array user_lower_phase1 = Array(); + Array user_lower_phase2 = Array(); + Array user_lower_phase3 = Array(); + + // phase pasees is of the form + // [[phase_number, pass], [phase_number, pass]... 
] + for (Array phase_pass : add_lower_pass) { + const IntImmNode* phase_num = phase_pass[0].as(); + ICHECK(phase_num) + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + int phase_num_val = phase_num->value; + + CHECK_GE(phase_num_val, 0); + + const tvm::transform::PassNode* pass_node = phase_pass[1].as(); + tvm::transform::Pass pass = GetRef(pass_node); + // Copy the pass into the correct phase + if (phase_num_val == 0) { + user_lower_phase0.push_back(pass); + } else if (phase_num_val == 1) { + user_lower_phase1.push_back(pass); + } else if (phase_num_val == 2) { + user_lower_phase2.push_back(pass); + } else if (phase_num_val >= 3) { + user_lower_phase3.push_back(pass); + } } - auto mod = IRModule(Map({{GlobalVar(name), f}})); - auto pass_list = Array(); + // Construct the pass list, inserting the user provided passes at the end of the phase + + // PHASE 0 + Array pass_list = user_lower_phase0; - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 + // PHASE 1 + if (for_te_schedule) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + } else { + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::FlattenBuffer()); + } pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LoopPartition()); + + // Add user-defined phase-1 passes + pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); + + // PHASE 2 + if (!disable_loop_partition) { + pass_list.push_back(tir::transform::LoopPartition()); + } + pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::UnrollLoop()); - // Phase 2 + + // Add user-defined phase-2 passes + pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); + + // PHASE 3 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + pass_list.push_back(tir::transform::HoistIfThenElse()); + + // Add user-defined phase-3 passes + pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); + if (instrument_bound_checkers) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } - // run - auto optimize = transform::Sequential(pass_list); + return pass_list; +} + +IRModule LowerWithPassList(IRModule mod, Array pass_list) { + auto optimize = tvm::transform::Sequential(pass_list); mod = optimize(std::move(mod)); return mod; } +IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { + // Convert te schedule to IRModule + Array out_arg_list; + transform::PassContext pass_ctx = transform::PassContext::Current(); + + sch = sch.normalize(); + + // Before TIR 
transformation. + Map bounds = te::InferBound(sch); + tir::Stmt stmt = te::ScheduleOps(sch, std::move(bounds), false); + bool compact = te::VerifyCompactBuffer(stmt); + + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // Build the function + // At this point binds is only te::Tensors + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + return IRModule(Map({{GlobalVar(name), f}})); +} + +TVM_REGISTER_GLOBAL("driver.schedule_to_module") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds); + return mod; + }); + +IRModule LowerModule(IRModule mod, bool simple_mode) { + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { + return LowerModule(std::move(mod), simple_mode); +}); + +IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(name), f}})); + + // Get the pass list + Array pass_list = CreatePassList(simple_mode, false); + return LowerWithPassList(std::move(mod), pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_primfunc") + .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { + return LowerPrimFunc(std::move(func), name, simple_mode); + }); + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + Array ref_args; + for (ObjectRef x : args) { + ref_args.push_back(x); + } + return LowerSchedule(std::move(sch), ref_args, name, binds); +} + +IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, bool simple_mode) { + IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); + // Get the legacy TE pass list + Array pass_list = CreatePassList(simple_mode, true); + return LowerWithPassList(mod, pass_list); +} + +TVM_REGISTER_GLOBAL("driver.lower_schedule") + .set_body_typed([](te::Schedule sch, const Array& args, const String& name, + const Map& binds, bool simple_mode) { + std::unordered_map c_binds; + // Check to make sure binds is not null before doing the conversion; + if (binds.get() != nullptr) { + for (auto kv : binds) { + c_binds.insert({kv.first, kv.second}); + } + } + return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); + }); + std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { diff --git a/src/parser/span_check.cc 
b/src/parser/span_check.cc index a72db5b9e4b6..7fed3730d926 100644 --- a/src/parser/span_check.cc +++ b/src/parser/span_check.cc @@ -71,9 +71,7 @@ void SpanChecker::VisitExpr_(const MatchNode* op) { ExprVisitor::VisitExpr_(op); void SpanChecker::VisitSpan(const Span& sp) { if (!sp.defined()) { Span span; - int i = 0; for (auto spans = this->span_stack.rbegin(); spans != this->span_stack.rend(); spans++) { - i += 1; span = this->span_stack.back(); if (span.defined()) { diag_ctx.Emit(Diagnostic::Warning(span) << "found null-span, i-nodes deep from this span."); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 2de331be9581..aad42fc9b0ea 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -54,6 +54,9 @@ namespace relay { */ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { Doc doc; + if (!opt_info_memo_.insert(expr).second) { + return doc; + } // default annotations if (annotate_ == nullptr) { if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { @@ -65,7 +68,6 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { doc << annotated_expr; } } - return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 52ab701008c7..7a529cc0b914 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -176,6 +176,8 @@ class RelayTextPrinter : public ExprFunctor, std::vector doc_stack_{}; /*! \brief Set for introduced vars */ std::unordered_set var_memo_; + /*! \brief Set for exprs have been printed optional information */ + std::unordered_set opt_info_memo_; /*! \brief Map for result and memo_ diffs for visited expression */ std::unordered_map result_memo_; /*! \brief Map from Expr to Doc */ diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa92b8f04edc..4bbe17064c87 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -115,6 +115,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitExpr_(const IntImmNode* op) override; Doc VisitExpr_(const FloatImmNode* op) override; Doc VisitExpr_(const StringImmNode* op) override; + Doc VisitExpr_(const ProducerLoadNode* op) override; Doc VisitExpr_(const BufferLoadNode* op) override; Doc VisitExpr_(const LoadNode* op) override; Doc VisitExpr_(const RampNode* op) override; @@ -387,19 +388,19 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { } else if (node->IsInstance()) { return PrintMatchBufferRegion(node.as()); } else { - meta_collector_.Collect(node); - return this->meta_.GetMetaNode(node); + LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); + return Doc(); } } Doc TVMScriptPrinter::VisitExprDefault_(const Object* op) { - meta_collector_.Collect(GetRef(op)); - return this->meta_.GetMetaNode(GetRef(op)); + LOG(FATAL) << "Do not know how to print " << op->GetTypeKey(); + return Doc(); } Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) { - meta_collector_.Collect(GetRef(op)); - return this->meta_.GetMetaNode(GetRef(op)); + LOG(FATAL) << "Do not know how to print " << op->GetTypeKey(); + return Doc(); } Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op) { @@ -414,11 +415,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiter Doc TVMScriptPrinter::VisitExpr_(const CastNode* op) { Doc doc; - if (cast(op->dtype, op->value)->IsInstance()) { - doc << Print(op->value) << ".astype(" << PrintDType(op->dtype) << ")"; - } else { - doc << "tir.cast(" << Print(op->value) << ", " << 
PrintDType(op->dtype) << ")"; - } + doc << "tir.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")"; return doc; } @@ -480,14 +477,24 @@ Doc TVMScriptPrinter::VisitExpr_(const NotNode* op) { Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op) { Doc doc; - doc << "tir.select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " + doc << "tir.Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " << Print(op->false_value) << ")"; return doc; } +Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op) { + LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to " + "lower this function first."; + return Doc(); +} + Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op) { Doc doc; - doc << Print(op->buffer) << Print(op->indices); + if (op->indices.size() == 0) { + doc << Print(op->buffer) << "[()]"; + } else { + doc << Print(op->buffer) << Print(op->indices); + } return doc; } @@ -661,12 +668,8 @@ Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) { Doc doc; - if (!is_one(op->predicate) || op->value.dtype().lanes() != 1) { - doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " - << Print(op->value) << ", " << Print(op->predicate) << ")"; - } else { - doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); - } + doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " + << Print(op->value) << ", " << Print(op->predicate) << ")"; return doc; } @@ -786,7 +789,11 @@ Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) { Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { Doc doc; - doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + if (op->indices.size() == 0) { + doc << Print(op->buffer) << "[()] = " << Print(op->value); + } else { + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + } return doc; } @@ -1051,17 +1058,21 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; - doc << Print(op->buffer) << "["; - for (size_t i = 0; i < op->region.size(); ++i) { - if (i != 0) doc << ", "; - const auto& range = op->region[i]; - if (!is_one(range->extent)) { - doc << Print(range->min) << ":" << Print(range->min + range->extent); - } else { - doc << Print(range->min); + if (op->region.size() == 0) { + doc << Print(op->buffer) << "[()]"; + } else { + doc << Print(op->buffer) << "["; + for (size_t i = 0; i < op->region.size(); ++i) { + if (i != 0) doc << ", "; + const auto& range = op->region[i]; + if (!is_one(range->extent)) { + doc << Print(range->min) << ":" << Print(range->min + range->extent); + } else { + doc << Print(range->min); + } } + doc << "]"; } - doc << "]"; return doc; } diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index acc1a9adc9f4..d8a5bb8e4f65 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -70,8 +70,8 @@ class WellFormedChecker : private MixedModeVisitor, PatternVisitor { void Bound(const Var& v) { if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) { - Illformed(Diagnostic::Error(v->span) << "the variable " << v->name_hint() - << "is bound more then once, this is not valid IR"); + Illformed(Diagnostic::Error(v->span) << "The variable " << v->name_hint() + << " is bound more than 
once, this is not valid IR"); } ICHECK_GE(scope.size(), 0); scope.back().insert(v); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index f72f3bd73557..29f7d30833a0 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -762,15 +762,9 @@ class CompileEngineImpl : public CompileEngineNode { all_args.push_back(arg); } // lower the function - if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); - } else { - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + std::unordered_map binds; + cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); - std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); - } value->cached_func = CachedFunc(cache_node); return value; } @@ -806,7 +800,7 @@ class CompileEngineImpl : public CompileEngineNode { With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); + cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 550afb3159fc..19b8c579cd8b 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -237,7 +237,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { // This segment would be generated in C++ because of the usage // of tvm::runtime::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external - // codegen. Moreover, in uTVM we dont expect this part to be generated. + // codegen. Moreover, in microTVM we dont expect this part to be generated. code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 32eecec25b06..0d575b3ec498 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -220,7 +220,7 @@ class CodegenCBase { // This segment would be generated in C++ because of the usage // of tvm::runtime::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external - // codegen. Moreover, in uTVM we dont expect this part to be generated. + // codegen. Moreover, in microTVM we dont expect this part to be generated. code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "int " << func_name << "_init_wrapper_(tvm::runtime::Array arr) {\n"; diff --git a/src/relay/backend/contrib/ethosn/capabilities.h b/src/relay/backend/contrib/ethosn/capabilities.h deleted file mode 100644 index cc14ca101da6..000000000000 --- a/src/relay/backend/contrib/ethosn/capabilities.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/backend/contrib/ethosn/capabilities.h - * \brief The Ethos-N processor series has four variants, the Ethos-N37, Ethos-N57, Ethos-N77 - * and the Ethos-N78. This release of the integration supports the first three variants and - * the default configuration of the fourth variant. - * Configuration information for each variant is stored as a blob in this file. These blobs - * are passed into the Ethos-N support library, which in turn uses them to optimize the - * generated command-stream appropriately for the specified variant. - */ - -#ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ -#define TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ - -#include - -#include "ethosn_api_version.h" - -namespace tvm { -namespace relay { -namespace contrib { -namespace ethosn { - -/* Ethos-N variants (Ethos-N77, Ethos-N57, Ethos-N37 and Ethos-N78) - * variant[0] - Ethos-N77 - * variant[1] - Ethos-N57 - * variant[2] - Ethos-N37 - * variant[3] - Ethos-N78 - */ -#if _ETHOSN_API_VERSION_ == 2011 -static std::vector variants[4] = { - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 
0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x02, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x02, - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - }}; -#else -static std::vector variants[4] = { - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - { - 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x02, 0x00, - 
0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - }}; -#endif -} // namespace ethosn -} // namespace contrib -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_CONTRIB_ETHOSN_CAPABILITIES_H_ diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index dab0e6c42f80..97b308e51e18 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -24,7 +24,6 @@ #include #include -#include "capabilities.h" #include "codegen_ethosn.h" #include "ethosn_api.h" @@ -198,19 +197,14 @@ sl::TensorsAndId MakeOps(const sl::TensorAndId& op) { NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { // Initialise everything -#if _ETHOSN_API_VERSION_ >= 2011 auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.ethos-n.options"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } -#endif NetworkWithIDs network_with_ids; -#if _ETHOSN_API_VERSION_ >= 2011 - network_ = sl::CreateNetwork(variants[cfg.value()->variant]); -#else - network_ = sl::CreateNetwork(); -#endif + network_ = sl::CreateNetwork(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); network_with_ids.network = network_; operand_table_.clear(); @@ -572,11 +566,7 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { cfg = AttrsWithDefaultValues(); } -#if _ETHOSN_API_VERSION_ >= 2011 sl::CompilationOptions options; -#else - sl::CompilationOptions options(variants[cfg.value()->variant]); -#endif options.m_Strategy0 = cfg.value()->strategy0; options.m_Strategy1 = cfg.value()->strategy1; options.m_Strategy3 = cfg.value()->strategy3; @@ -590,9 +580,6 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { options.m_BlockConfig8x32 = cfg.value()->block_config_8x32; options.m_BlockConfig8x8 = cfg.value()->block_config_8x8; options.m_EnableIntermediateCompression = cfg.value()->enable_intermediate_compression; -#if _ETHOSN_API_VERSION_ == 2008 - options.m_DebugInfo.m_DumpDebugFiles = cfg.value()->dump_debug_files; -#endif options.m_DisableWinograd = cfg.value()->disable_winograd; options.m_DebugInfo.m_DebugDir = cfg.value()->debug_dir; options.m_CompilerAlgorithm = @@ -619,20 +606,18 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -#if _ETHOSN_API_VERSION_ >= 2011 auto ctx = transform::PassContext::Current(); auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() ? 
ctx -> GetConfig("relay.ext.ethos-n.options") : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(variants[cfg.value()->variant]); -#endif +auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 if (params.is_depthwise) { *rv = !err && m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, @@ -641,15 +626,6 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, params.conv_info, params.activation_info); } -#else - if (params.is_depthwise) { - *rv = !err && sl::IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); - } else { - *rv = !err && sl::IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); - } -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") @@ -657,13 +633,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, params.fc_info, params.input_info); -#else - *rv = !err && sl::IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -671,11 +642,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); -#else - *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -683,11 +650,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); -#else - *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -695,11 +658,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); -#else - *rv = !err && sl::IsReshapeSupported(params.new_shape, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -707,13 +666,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, params.output_quantization_info); -#else - *rv = !err && sl::IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ 
-721,11 +675,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); -#else - *rv = !err && sl::IsSigmoidSupported(params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -733,11 +683,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); -#else - *rv = !err && sl::IsConcatenationSupported(params.input_infos, params.concat_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -745,11 +691,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); -#else - *rv = !err && sl::IsSplitSupported(params.input_info, params.split_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -757,11 +699,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); -#else - *rv = !err && sl::IsDepthToSpaceSupported(params.input_info, params.depth_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -769,11 +707,7 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); -#if _ETHOSN_API_VERSION_ >= 2011 *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); -#else - *rv = !err && sl::IsReluSupported(params.relu_info, params.input_info); -#endif }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index e44aa31d6b13..63ae7a3e4704 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -226,7 +226,8 @@ NetworkWithIDs ConstructNetwork(const IRModule& mod, const GlobalVar& var, const /*! 
\brief Attributes to store the compiler options for Ethos-N */ struct EthosnCompilerConfigNode : public tvm::AttrsNode { - int variant; + String variant; + int sram_size_bytes; bool strategy0; bool strategy1; bool strategy3; @@ -240,18 +241,14 @@ struct EthosnCompilerConfigNode : public tvm::AttrsNode storage_ids, std::vector device_types, + std::vector storage_sizes_in_bytes) { + auto n = make_object(); + n->storage_ids = std::move(storage_ids); + n->device_types = std::move(device_types); + n->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode); + +StaticMemoryPlan::StaticMemoryPlan(Map expr_to_storage_info) { + auto n = make_object(); + n->expr_to_storage_info = std::move(expr_to_storage_info); + data_ = std::move(n); +} + int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4f7cbde5b62c..7d7f026c298e 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -46,6 +46,53 @@ namespace tvm { namespace relay { namespace backend { +/*! + * \brief The static storage information produced by memory planning. + */ +class StorageInfoNode : public Object { + public: + /*! \brief The set of storage ids where the expression is stored. */ + std::vector storage_ids; + /* \brief The type of "virtual devices" these expressions are stored on. */ + std::vector device_types; + /* \brief The sizes of each storage element. */ + std::vector storage_sizes_in_bytes; + + // TODO(@jroesch): expose the fields + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.StorageInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageInfoNode, Object); +}; + +/*! \brief The storage information for a single expression. */ +class StorageInfo : public ObjectRef { + public: + StorageInfo(std::vector storage_ids, std::vector device_types, + std::vector storage_sizes_in_bytes); + TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode); +}; + +/*! + * \brief The result of static memory planning. + */ +class StaticMemoryPlanNode : public Object { + public: + Map expr_to_storage_info; + + void VisitAttrs(AttrVisitor* v) { v->Visit("expr_to_storage_info", &expr_to_storage_info); } + + static constexpr const char* _type_key = "relay.StaticMemoryPlan"; + TVM_DECLARE_FINAL_OBJECT_INFO(StaticMemoryPlanNode, Object); +}; + +/*! \brief The result of running static memory planning. 
*/ +class StaticMemoryPlan : public ObjectRef { + public: + explicit StaticMemoryPlan(Map expr_to_storage_info); + TVM_DEFINE_OBJECT_REF_METHODS(StaticMemoryPlan, ObjectRef, StaticMemoryPlanNode); +}; + struct FunctionInfoNode : public Object { Map workspace_sizes; Map io_sizes; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 6ed24d5053c4..5ce06d9fefaa 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -131,6 +131,8 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { return rhs.operator std::string() == val->value; } else if (auto* val = lhs.as()) { return rhs.operator std::string() == val->data; + } else { + ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs; } break; case kTVMObjectHandle: @@ -140,6 +142,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } else if (auto* val = lhs.as()) { return rhs.operator String() == val->data; } + } else { + // Compare the objects for structural equality + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + return true; + } } break; default: diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 9c3d60198add..2c90d7b8a057 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -33,6 +33,31 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); +template +Array > ResizeInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // NOTE: Discard "const" qualifier here. + T* params = const_cast(attrs.as()); + + if (new_in_layouts.defined()) { + ICHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout new_layout = new_in_layouts[0]; + Layout old_layout = old_in_layouts[0]; + if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) && + new_layout->axes.size() == old_layout->axes.size()) { + // Follow input layout + params->layout = new_layout.name(); + } + } + + Layout inferred_layout(params->layout); + return Array >{{inferred_layout}, {inferred_layout}}; +} + bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); @@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) .add_type_rel("Resize", ResizeRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(Resize3dAttrs); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9361e1996796..5dc2a677f13f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3290,6 +3290,8 @@ which must just be not null. Output will have same shape as ``indices``. .set_attr("FTVMCompute", GatherCompute) .set_attr("TOpPattern", kInjective); +TVM_REGISTER_NODE_TYPE(GatherNDAttrs); + // gather_nd operator bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -3367,6 +3369,7 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}). 
)code" TVM_ADD_FILELINE) .set_num_inputs(2) + .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices of values to gather.") .set_support_level(3) @@ -3973,5 +3976,23 @@ RELAY_REGISTER_OP("unique") .add_type_rel("unique", UniqueRel) .set_support_level(3) .set_attr("TOpPattern", kOpaque); + +// invert_permutation +Expr MakeInvertPermutation(Expr data) { + static const Op& op = Op::Get("invert_permutation"); + return Call(op, {data}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation); + +RELAY_REGISTER_OP("invert_permutation") + .describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(1) + .set_attr("TOpPattern", kInjective) + .set_attr("TOpIsStateful", false); + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 1a3c459967bc..f883b4113656 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -146,7 +146,7 @@ class SubgraphExtractor : public ExprVisitor { return subgraph; } const AffineTypeMap GetAffineTypes() { return affine_types_; } - void VisitExpr(const Expr& expr) { + void VisitExpr(const Expr& expr) override { if (expr.as() == nullptr && expr.as() == nullptr && expr.as() == nullptr) { is_fake_quantized_ = false; diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index 7e587664b4dc..846bc08e3054 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -178,7 +178,7 @@ Expr L2NormToInferUnpack(const Attrs attrs, Expr data) { return Divide(data, sqrt); } -class InferenceSimplifier : public ExprMutator { +class InferenceSimplifier : public MixedModeMutator { public: InferenceSimplifier() : batch_norm_op_(Op::Get("nn.batch_norm")), @@ -188,8 +188,7 @@ class InferenceSimplifier : public ExprMutator { group_norm_op_(Op::Get("nn.group_norm")), l2_norm_op_(Op::Get("nn.l2_normalize")) {} - Expr VisitExpr_(const TupleGetItemNode* n) final { - Expr new_e = ExprMutator::VisitExpr_(n); + Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final { const auto* new_n = new_e.as(); if (new_n->index != 0) { return new_e; @@ -205,8 +204,7 @@ class InferenceSimplifier : public ExprMutator { return new_e; } - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); + Expr Rewrite_(const CallNode* n, const Expr& new_n) { if (n->op == batch_norm_op_) { ty_map_[new_n.as()->args[0]] = n->args[0]->checked_type(); } else if (n->op == layer_norm_op_) { diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc new file mode 100644 index 000000000000..ae10c937ff1c --- /dev/null +++ b/src/relay/transforms/to_mixed_precision.cc @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file to_mixed_precision.cc + * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16. + * + */ + +#include +#include +#include +#include + +#include + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +// A callable which hashes std::pair +struct pair_hash { + template + std::size_t operator()(const std::pair& pair) const { + auto h1 = std::hash()(pair.first); + auto h2 = std::hash()(pair.second); + + // Use boost's combine_hash strategy + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); + } +}; + +// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory +// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to +// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to +// numerical reasons. +enum MixedTypeConversionCategory : int { + MIXED_PRECISION_ALWAYS = 0, + MIXED_PRECISION_FOLLOW = 1, + MIXED_PRECISION_NEVER = 2 +}; + +// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype +using CachedCastNodes = std::unordered_map, Expr, pair_hash>; + +// Return array is of type : [MixedTypeConversionCategory (int), String, String] +// The fields are : [ConversionCategory, accumulation_datatype, output_datatype] +// Call is a call node, DataType is the mixed precision type +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( + const Call& call_node, const std::string& target_dtype_str)>; + +/*! \brief This class transforms the given relay module into a version where + * as many operations as possible operate in the target mixed precision dtype. + * + * Input : A Relay module with operations registered with FTVMMixedPrecisionConversionType + * functions. These describe when and how the operations will be transformed + * into the target precision dtype. + * + * Output : A Relay module with some operations transformed according to the below + * methodology. + * + * Methodology : + * 1) Each relay Op is either of conversion category ALWAYS, FOLLOW, NEVER + * defined by the associated FTVMMixedPrecisionConversionType function. + * If an operation is not registered, it by default is assumed to be + * FOLLOW. + * 2) ALWAYS operations always convert the input floating point args into + * the target mixed precision dtype. FOLLOW Ops will convert the input + * floating point args back into FP32 unless all floating point args + * are in the target mixed precision dtypes. NEVER ops will always cast + * inputs back into FP32. + * 3) Each ALWAYS Op, and FOLLOW Op with mixed precision dtype arguments + * also have an associated accumulation_dtype and output_dtype which + * describe whether a larger dtype is used to accumulate the results + * of the operation. The output_dtype meanwhile describes the dtype + * most Ops should use from this accumulator. + */ +class MixedPrecisionPass : public MixedModeMutator { + private: + /*! \brief A cache of nodes + target dtype to a cast version of the node with target dtype. 
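// Standalone illustration (not part of this patch): pair_hash above mixes the two
// element hashes with the boost-style hash_combine step rather than a plain XOR, so
// pairs with equal or swapped members do not collapse onto the same value. Assumes
// only the standard library.
#include <cstddef>
#include <cstdio>
#include <functional>

static std::size_t Combine(std::size_t h1, std::size_t h2) {
  // Same mixing step as in the pass: h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)).
  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
}

int main() {
  int node = 0;                                         // stand-in for an ExprNode
  std::size_t h1 = std::hash<const int*>()(&node);      // hash of the node pointer
  std::size_t h2 = std::hash<int>()(16);                // hash of a dtype code
  std::printf("xor only:     %zu\n", h1 ^ h1);          // collapses to 0 for equal hashes
  std::printf("hash_combine: %zu\n", Combine(h1, h2));  // stays well distributed
  return 0;
}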
*/ + CachedCastNodes cast_nodes_cache_; + + /*! \brief The target datatype we want to convert to e.g. FP16 */ + const DataType mixed_precision_type_; + + /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were + * encountered. Used for emitting warnings on missing ops in the pass. + */ + std::unordered_map missing_ops_; + + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { + /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ + Attrs cur_attrs = call->attrs; + if (cur_attrs.get() != nullptr) { + // TODO(AndrewZhaoLuo): Figure out a better way to do this + // modify output_dtype attributes (accumulation dtypes for ops) + if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = cur_attrs.as()) { + return ModifyAttrsOutputDType(attrs, accumulation_dtype); + } + + // modify dtype attributes (creating new tensors of type dtype) + if (auto attrs = cur_attrs.as()) { + return ModifyAttrsDType(attrs, accumulation_dtype); + } + } + + return cur_attrs; + } + + template + Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const { + /* + Helper template to modify relevant attributes with out_dtype type. + These represent accumulation dtypes for some operations e.g. + conv2d might take in fp16 and give a fp32 result. + Attrs is const because we get it as a const. + */ + DataType cur_type = (attrs->out_dtype); + ObjectPtr new_attrs = make_object(*attrs); + if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype; + return Attrs(new_attrs); + } + + template + Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const { + /* + Helper template to modify relevant attributes with dtype type. + This determines the output dtype for some ops. For example + zeros creates a tensor of zeros of the specified dtype. + Attrs is const because we get it as a const. 
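// Standalone illustration (not from this patch): the cast cache is keyed on
// (source node pointer, wanted dtype) and also records the reverse entry, so casting a
// value to the mixed type and then back returns the original node instead of stacking
// two cast ops. Strings stand in for dtypes; the toy intentionally leaks its nodes.
#include <map>
#include <string>
#include <utility>

struct Node { std::string dtype; };

std::map<std::pair<const Node*, std::string>, const Node*> cast_cache;

const Node* CachedCastSketch(const Node* expr, const std::string& wanted) {
  if (expr->dtype == wanted) return expr;            // nothing to do
  auto it = cast_cache.find({expr, wanted});
  if (it != cast_cache.end()) return it->second;     // reuse the earlier cast
  const Node* result = new Node{wanted};             // stand-in for Cast(expr, wanted)
  cast_cache[{expr, wanted}] = result;
  cast_cache[{result, expr->dtype}] = expr;          // reverse entry undoes the cast
  return result;
}

int main() {
  Node x{"float32"};
  const Node* y = CachedCastSketch(&x, "float16");
  return CachedCastSketch(y, "float32") == &x ? 0 : 1;  // round trip yields the original
}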
+ */ + DataType cur_type = (attrs->dtype); + ObjectPtr new_attrs = make_object(*attrs); + if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype; + return Attrs(new_attrs); + } + + Type GetType(const Expr& expr) const { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main")->checked_type(); + } else { + return mod->Lookup("main").as()->body->checked_type(); + } + } + + bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const { + /* Returns whether t is a type with only target mixed precision type elements. + If ignore_non_float, then ignore non-floating types. + */ + if (const TensorTypeNode* tensor_type = t.as()) { + return (!ignore_non_float || (tensor_type->dtype).is_float()) && + tensor_type->dtype == mixed_precision_type_; + } else if (const TupleTypeNode* tuple_type = t.as()) { + for (Type t : tuple_type->fields) { + if (!IsMixedPrecisionType(t, ignore_non_float)) return false; + } + return true; + } else { + LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle"; + return false; + } + } + + Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) { + /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */ + + // If this is not a floating point type, do not cast. E.g. it might be an integer + if (!expr_dtype.is_float()) { + return expr; + } + + if (expr_dtype == wanted_dtype) { + return expr; + } + + const ExprNode* expr_node = expr.as(); + CHECK(expr_node) << "Non-expression node found in cast: " << expr; + + // Use cached result if possible. + auto search = cast_nodes_cache_.find({expr_node, wanted_dtype}); + if (search != cast_nodes_cache_.end()) { + return search->second; + } + + Expr result = Cast(expr, wanted_dtype); + cast_nodes_cache_[{expr_node, wanted_dtype}] = result; + + // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node + const ExprNode* new_expr_node = result.as(); + cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr; + return result; + } + + Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) { + /* Helper for casting arguments to call_nodes handling all relevant cases. */ + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return CachedCast(expr, tensor_type->dtype, wanted_dtype); + } else if (const TupleTypeNode* tuple_type = expr_type.as()) { + Array new_expr; + bool all_same = true; + for (size_t i = 0; i < (tuple_type->fields).size(); i++) { + Expr tuple_element = GetField(expr, i); + Type tuple_element_dtype = (tuple_type->fields)[i]; + Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype); + new_expr.push_back(casted_element); + all_same &= casted_element.same_as(tuple_element); + } + return all_same ? 
expr : Tuple(new_expr); + } + CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!"; + return expr; + } + + std::pair, Array> CastAllArgs(const Array& cur_args, + const Array& cur_arg_types, + const DataType& wanted_dtype) { + Array new_args; + Array new_arg_types; + for (size_t i = 0; i < cur_args.size(); i++) { + Expr cur_arg = cur_args[i]; + Type cur_arg_type = cur_arg_types[i]; + Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype); + Type new_arg_type = GetType(new_arg); + new_args.push_back(new_arg); + new_arg_types.push_back(new_arg_type); + } + return {new_args, new_arg_types}; + } + + public: + using MixedModeMutator::VisitExpr_; + + explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16)) + : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) { + if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) { + LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " + << mixed_precision_type_; + } + } + + Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { + const CallNode* post_call_node = post.as(); + CHECK(post_call_node) << "Expected a CallNode, but got " << post; + + Expr cur_op = post_call_node->op; + + // TODO(AndrewZhaoLuo): Support ADTs + // Relay's algebraic data types are not supported yet. + ICHECK(!cur_op.as() // used to declare functions for recursion + && !cur_op.as() // constructing ADT types + && !cur_op.as()) // used for calling recursive functions + << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass."; + + // Get info on the operation being called: + // conversion category (int), accumulation dtype (str), output dtype (str) + MixedTypeConversionCategory initial_category; + DataType accumulation_dtype, output_dtype; + if (cur_op.as()) { + // Avoid messing with functions to avoid changing signature + initial_category = MIXED_PRECISION_NEVER; + accumulation_dtype = DataType::Float(32); + output_dtype = DataType::Float(32); + } else if (cur_op.as()) { + static auto attr_map = + Op::GetAttrMap("FTVMMixedPrecisionConversionType"); + Op op = Downcast(cur_op); + if (attr_map.count(op)) { + // Calculate the conversion category and dtypes from registered attribute. + FTVMMixedPrecisionConversionType func = attr_map[op]; + Array op_descriptor = + func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); + ICHECK(op_descriptor.size() == 3) + << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() + << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false); + + int64_t op_conversion_type = Downcast(op_descriptor[0])->value; + initial_category = static_cast(op_conversion_type); + accumulation_dtype = DataType(String2DLDataType(Downcast(op_descriptor[1]))); + output_dtype = DataType(String2DLDataType(Downcast(op_descriptor[2]))); + } else { + missing_ops_[op->name] += 1; + + // If not registered, by default assume is a generic FOLLOW operation. 
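// Illustrative sketch, not part of the change: how a FOLLOW-colored op is resolved.
// A FOLLOW op only stays in the reduced-precision type when every floating point
// argument is already in that type (Vars and Constants count as freely castable);
// otherwise it falls back to fp32. Standalone analogue:
#include <vector>

enum Category { kAlways, kFollow, kNever };

Category Resolve(Category initial, const std::vector<bool>& arg_mixed_or_castable) {
  if (initial != kFollow) return initial;      // ALWAYS and NEVER are fixed up front
  for (bool ok : arg_mixed_or_castable) {
    if (!ok) return kNever;                    // one fp32-only argument pulls the op back
  }
  return kAlways;                              // all arguments already compatible
}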
+ initial_category = MIXED_PRECISION_FOLLOW; + accumulation_dtype = mixed_precision_type_; + output_dtype = mixed_precision_type_; + } + } else { + LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op; + } + + // First check if all the new mutated args are in lower precision form + Array cur_arg_types; + bool all_args_mixed_type_compatible = true; + for (Expr arg : post_call_node->args) { + Type cur_arg_type = GetType(arg); + cur_arg_types.push_back(cur_arg_type); + + if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) { + // We can cast Vars and Constants to the right types so don't care about the types. + bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) || + arg->IsInstance() || + arg->IsInstance(); + all_args_mixed_type_compatible &= is_mixed_type_compatible; + } + } + + // Determine the final category we want for conversion + MixedTypeConversionCategory final_category = initial_category; + if (initial_category == MIXED_PRECISION_FOLLOW) { + final_category = + all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER; + } + + // Create the new arguments to the call. + DataType wanted_arg_dtypes = + final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32); + auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes); + Array new_args = call_args_and_types.first; + Array new_arg_types; + + if (pre_call_node->op.as()) { + // Function Nodes don't store type info in the Call, it should be a [] + new_arg_types = pre_call_node->type_args; + } else { + new_arg_types = call_args_and_types.second; + } + + // Finally create the new attributes. + if (final_category == MIXED_PRECISION_ALWAYS) { + Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype); + Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span); + if (accumulation_dtype != output_dtype) { + output = CastArg(output, GetType(output), output_dtype); + } + return output; + } + + return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span); + } + + Expr VisitExpr_(const FunctionNode* func) final { + // Erase the ret_type annotation and let the normal pass recalculate + const_cast(func)->ret_type = Type(nullptr); + return ExprMutator::VisitExpr_(func); + } + + Expr VisitExpr_(const LetNode* op) final { + // First convert as much of the bound computation to lower precision as possible + Expr value = this->Mutate(op->value); + + // Then rewrite the var type and associated expression + Var var = Downcast(this->Mutate(op->var)); + VarNode* mutable_var = const_cast((op->var).as()); + mutable_var->type_annotation = GetType(value); + mutable_var->checked_type_ = mutable_var->type_annotation; + + // Mutate body last as it may depend on previous results + Expr body = this->Mutate(op->body); + return Let(var, value, body, op->span); + } + + // To access map of ops not registered for error reporting + friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, + int missing_op_mode); +}; + +Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) { + /* + missing_op_mode: + + 0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any. + 1: Allow missing ops but throw warnings. + 2: Allow missing ops and silently ignore them. 
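// For illustration only (not part of the patch): the three missing_op_mode policies,
// mirrored in a standalone helper. The helper name and signature are hypothetical.
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>

void ReportMissingOps(const std::map<std::string, int>& missing_ops, int missing_op_mode) {
  if (missing_op_mode == 2) return;                       // 2: silently ignore
  for (const auto& kv : missing_ops) {                    // 1: warn for each missing op
    std::fprintf(stderr, "op %s has no conversion rule (%d uses)\n",
                 kv.first.c_str(), kv.second);
  }
  if (missing_op_mode == 0 && !missing_ops.empty()) {
    std::abort();                                         // 0: any missing op is fatal
  }
}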
+ */ + ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2) + << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode; + + MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type); + auto result = converter.Mutate(expr); + + for (auto it = converter.missing_ops_.begin(); + missing_op_mode != 2 && it != converter.missing_ops_.end(); it++) { + std::string op_name = it->first; + int appear_count = it->second; + + LOG(WARNING) << "Op \"" << op_name << "\" not registered " + << "FTVMMixedPrecisionConversionType appears " << appear_count + << " times in graph."; + } + + if (converter.missing_ops_.size() != 0 && missing_op_mode == 0) { + CHECK(0) << "Missing ops were found!"; + } + return result; +} + +namespace transform { + +Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ToMixedPrecision(f, mixed_precision_type, missing_op_mode)); + }; + return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 6562d1bfc62d..5bbc536afaca 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -381,9 +381,9 @@ class ACLRuntime : public JSONRuntimeBase { void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) { std::vector padding = node.GetAttr>("padding"); std::vector strides = node.GetAttr>("strides"); + std::vector dilation = node.GetAttr>("dilation"); bool ceil_mode = std::stoi(node.GetAttr>("ceil_mode")[0]); arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides, ceil_mode); - auto attr_pool_size = node.GetAttr>("pool_size"); int pool_size_h = std::stoi(attr_pool_size[0]); int pool_size_w = std::stoi(attr_pool_size[1]); @@ -408,6 +408,8 @@ class ACLRuntime : public JSONRuntimeBase { LOG(FATAL) << "Pooling type not supported"; } + ICHECK(dilation.size() == 2 && dilation[0] == "1" && dilation[1] == "1") + << "Dilation other than (1, 1) not supported"; arm_compute::PoolingLayerInfo pool_info = arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w), arm_compute::DataLayout::NHWC, pad_stride_info, exclude_pad); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 9ccfa5183cd6..4724b14bffa1 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -181,28 +181,48 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); + // C can never be transposed. ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? 
!transb : transb; + ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } + DType* A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); DType* B_data = reinterpret_cast(static_cast(B->data) + @@ -210,9 +230,9 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DType* C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), - ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, - ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), - static_cast(beta), C_data, C_size, ColumnStride3D(C)); + ColumnCount3D(A, transa), static_cast(alpha), B_data, + B_stride, ColumnStride3D(B), A_data, A_stride, ColumnStride3D(A), + static_cast(beta), C_data, C_stride, ColumnStride3D(C)); } } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 1216a63703bb..015d68aec819 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -275,9 +275,8 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); @@ -299,9 +298,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. 
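// Standalone illustration (not from this patch): why a zero batch stride implements
// broadcasting in strided-batched GEMM. Each batch i reads its operand at
// base + i * stride, so a stride of 0 makes every batch reuse the single matrix while
// the other operand advances normally.
#include <cstdio>

int main() {
  const int batch = 3, elems = 4;
  float a[elems] = {1, 2, 3, 4};        // batch_size_a == 1, so A_stride becomes 0
  float b[batch * elems] = {};          // full batch, B_stride stays elems
  const int A_stride = 0, B_stride = elems;
  for (int i = 0; i < batch; ++i) {
    const float* a_i = a + i * A_stride;   // always the same block
    const float* b_i = b + i * B_stride;   // advances by one matrix per batch
    std::printf("batch %d: a offset %td, b offset %td\n", i, a_i - a, b_i - b);
  }
  return 0;
}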
+ int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype); cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype); @@ -325,8 +338,9 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( hdl, CUBLASBooleanToTranspose(transb), CUBLASBooleanToTranspose(transa), ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, - cuda_in_type, ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); + cuda_in_type, ColumnStride3D(B), B_stride, A_data, cuda_in_type, ColumnStride3D(A), A_stride, + beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_stride, batch_size, cuda_out_type, + algo)); } // matrix multiplication for row major diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index ad3b959338bb..2d7f82694929 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -156,7 +156,9 @@ void OutputShape(int format, int dims, int groups, const int pad[], const int st dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only defined for 4d tensors"; + // Set Input CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], @@ -295,7 +297,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn") .set_body([](TVMArgs args, TVMRetValue* ret) { int format = args[0]; int dims = args[1]; diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index da67c2e1a9a5..a320c9236ceb 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -99,16 +99,38 @@ CuDNNThreadEntry::CuDNNThreadEntry() { auto func = runtime::Registry::Get("device_api.cuda"); void* ret = (*func)(); cuda_api = static_cast(ret); - CUDNN_CALL(cudnnCreate(&handle)); + + // If no CuDNN-capable device is present, allow the CuDNNThreadEntry + // object to be created. This is needed for + // CuDNNThreadEntry::exists. 
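// Illustrative sketch, not part of the patch: the "create the thread entry even when
// cuDNN cannot initialize" pattern. A failed create leaves the handle null, exists()
// reports availability (also surfaced through the tvm.contrib.cudnn.exists packed
// function registered below), and the destructor only destroys a handle that was
// actually created. Standalone analogue with a fake handle type:
#include <iostream>

struct FakeHandle {};
FakeHandle* TryCreate() { return nullptr; }       // stand-in for cudnnCreate failing

struct ThreadEntrySketch {
  FakeHandle* handle = nullptr;
  ThreadEntrySketch() { handle = TryCreate(); }   // tolerate failure, keep nullptr
  ~ThreadEntrySketch() { if (handle) delete handle; }  // tear down only what was created
  bool exists() const { return handle != nullptr; }
};

int main() {
  ThreadEntrySketch entry;
  std::cout << "cudnn available: " << std::boolalpha << entry.exists() << "\n";
  return 0;
}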
+ { + cudnnStatus_t create_res = cudnnCreate(&handle); + if (create_res == CUDNN_STATUS_NOT_INITIALIZED) { + return; + } + CUDNN_CALL(create_res); + } + CUDNN_CALL(cudnnSetStream(handle, stream)); conv_entry.cuda_api = cuda_api; } -CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); } +CuDNNThreadEntry::~CuDNNThreadEntry() { + if (handle) { + CUDNN_CALL(cudnnDestroy(handle)); + } +} typedef dmlc::ThreadLocalStore CuDNNThreadStore; -CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); } +CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) { + auto* res = CuDNNThreadStore::Get(); + if (check_exists) { + ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; + } + + return res; +} // ConvEntry @@ -148,5 +170,9 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { + return CuDNNThreadEntry::ThreadLocal(false)->exists(); +}); + } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 72380b64121a..01b92d61e66e 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -93,11 +93,14 @@ struct SoftmaxEntry { struct CuDNNThreadEntry { CuDNNThreadEntry(); ~CuDNNThreadEntry(); + + bool exists() const { return handle; } + cudnnHandle_t handle{nullptr}; ConvEntry conv_entry; SoftmaxEntry softmax_entry; runtime::DeviceAPI* cuda_api{nullptr}; - static CuDNNThreadEntry* ThreadLocal(); + static CuDNNThreadEntry* ThreadLocal(bool check_exists = true); }; // CuDNNThreadEntry } // namespace contrib diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index a7a57ff422e3..341062f1c492 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -31,6 +31,7 @@ #include #include "../tflite/tflite_runtime.h" +#include "edgetpu.h" namespace tvm { namespace runtime { @@ -43,6 +44,14 @@ namespace runtime { */ class EdgeTPURuntime : public TFLiteRuntime { public: + /*! + * \brief Destructor of EdgeTPURuntime. + * + * NOTE: tflite::Interpreter member should be destruct before the EdgeTpuContext member + * destruction. If the order is reverse, occurs SEGV in the destructor of tflite::Interpreter. + */ + ~EdgeTPURuntime() { interpreter_.reset(); } + /*! * \return The type key of the executor. 
*/ diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index 38c53d273a6e..f458a2f08002 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -71,8 +71,8 @@ LIBS = \ src/runtime/crt/aot_executor \ src/runtime/crt/graph_executor_module \ src/runtime/crt/memory \ - src/runtime/crt/utvm_rpc_common \ - src/runtime/crt/utvm_rpc_server + src/runtime/crt/microtvm_rpc_common \ + src/runtime/crt/microtvm_rpc_server $(foreach lib,$(LIBS),$(eval $(call LIB_template,$(lib)))) diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc index 0b0c81169756..65027dd67e8c 100644 --- a/src/runtime/crt/host/main.cc +++ b/src/runtime/crt/host/main.cc @@ -25,8 +25,8 @@ #include #include #include +#include #include -#include #include #include @@ -42,7 +42,7 @@ using namespace std::chrono; extern "C" { -ssize_t UTvmWriteFunc(void* context, const uint8_t* data, size_t num_bytes) { +ssize_t MicroTVMWriteFunc(void* context, const uint8_t* data, size_t num_bytes) { ssize_t to_return = write(STDOUT_FILENO, data, num_bytes); fflush(stdout); fsync(STDOUT_FILENO); @@ -69,29 +69,29 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { return memory_manager->Free(memory_manager, ptr, dev); } -steady_clock::time_point g_utvm_start_time; -int g_utvm_timer_running = 0; +steady_clock::time_point g_microtvm_start_time; +int g_microtvm_timer_running = 0; tvm_crt_error_t TVMPlatformTimerStart() { - if (g_utvm_timer_running) { + if (g_microtvm_timer_running) { std::cerr << "timer already running" << std::endl; return kTvmErrorPlatformTimerBadState; } - g_utvm_start_time = std::chrono::steady_clock::now(); - g_utvm_timer_running = 1; + g_microtvm_start_time = std::chrono::steady_clock::now(); + g_microtvm_timer_running = 1; return kTvmErrorNoError; } tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_utvm_timer_running) { + if (!g_microtvm_timer_running) { std::cerr << "timer not running" << std::endl; return kTvmErrorPlatformTimerBadState; } - auto utvm_stop_time = std::chrono::steady_clock::now(); - std::chrono::microseconds time_span = - std::chrono::duration_cast(utvm_stop_time - g_utvm_start_time); + auto microtvm_stop_time = std::chrono::steady_clock::now(); + std::chrono::microseconds time_span = std::chrono::duration_cast( + microtvm_stop_time - g_microtvm_start_time); *elapsed_time_seconds = static_cast(time_span.count()) / 1e6; - g_utvm_timer_running = 0; + g_microtvm_timer_running = 0; return kTvmErrorNoError; } @@ -117,7 +117,7 @@ static char** g_argv = NULL; int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle) { execvp(g_argv[0], g_argv); - perror("utvm runtime: error restarting"); + perror("microTVM runtime: error restarting"); return -1; } @@ -130,7 +130,7 @@ int main(int argc, char** argv) { return 2; } - utvm_rpc_server_t rpc_server = UTvmRpcServerInit(&UTvmWriteFunc, nullptr); + microtvm_rpc_server_t rpc_server = MicroTVMRpcServerInit(&MicroTVMWriteFunc, nullptr); #ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE CHECK_EQ(TVMGraphExecutorModule_Register(), kTvmErrorNoError, @@ -140,9 +140,10 @@ int main(int argc, char** argv) { int error = TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server, 0); if (error) { - fprintf(stderr, - "utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n", - error); + fprintf( + stderr, + "microTVM runtime: internal error (error#: %x) registering global 
packedfunc; exiting\n", + error); return 2; } @@ -153,21 +154,21 @@ int main(int argc, char** argv) { uint8_t c; int ret_code = read(STDIN_FILENO, &c, 1); if (ret_code < 0) { - perror("utvm runtime: read failed"); + perror("microTVM runtime: read failed"); return 2; } else if (ret_code == 0) { - fprintf(stderr, "utvm runtime: 0-length read, exiting!\n"); + fprintf(stderr, "microTVM runtime: 0-length read, exiting!\n"); return 2; } uint8_t* cursor = &c; size_t bytes_to_process = 1; while (bytes_to_process > 0) { - tvm_crt_error_t err = UTvmRpcServerLoop(rpc_server, &cursor, &bytes_to_process); + tvm_crt_error_t err = MicroTVMRpcServerLoop(rpc_server, &cursor, &bytes_to_process); if (err == kTvmErrorPlatformShutdown) { break; } else if (err != kTvmErrorNoError) { char buf[1024]; - snprintf(buf, sizeof(buf), "utvm runtime: UTvmRpcServerLoop error: %08x", err); + snprintf(buf, sizeof(buf), "microTVM runtime: MicroTVMRpcServerLoop error: %08x", err); perror(buf); return 2; } diff --git a/src/runtime/crt/utvm_rpc_common/frame_buffer.cc b/src/runtime/crt/microtvm_rpc_common/frame_buffer.cc similarity index 100% rename from src/runtime/crt/utvm_rpc_common/frame_buffer.cc rename to src/runtime/crt/microtvm_rpc_common/frame_buffer.cc diff --git a/src/runtime/crt/utvm_rpc_common/framing.cc b/src/runtime/crt/microtvm_rpc_common/framing.cc similarity index 98% rename from src/runtime/crt/utvm_rpc_common/framing.cc rename to src/runtime/crt/microtvm_rpc_common/framing.cc index 857ed2a23bec..f89c6e5688c0 100644 --- a/src/runtime/crt/utvm_rpc_common/framing.cc +++ b/src/runtime/crt/microtvm_rpc_common/framing.cc @@ -34,8 +34,9 @@ // framer in its implementation. #ifdef TVM_CRT_FRAMER_ENABLE_LOGS #include -#define TVM_FRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm framer: " msg " \n", ##__VA_ARGS__) -#define TVM_UNFRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm unframer: " msg " \n", ##__VA_ARGS__) +#define TVM_FRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "microTVM framer: " msg " \n", ##__VA_ARGS__) +#define TVM_UNFRAMER_DEBUG_LOG(msg, ...) \ + fprintf(stderr, "microTVM unframer: " msg " \n", ##__VA_ARGS__) #else #define TVM_FRAMER_DEBUG_LOG(msg, ...) #define TVM_UNFRAMER_DEBUG_LOG(msg, ...) 
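// Standalone illustration (not from this patch): main() above pumps the RPC server one
// byte at a time, passing the cursor and remaining-byte count by pointer so the server
// can consume only part of the input per call. Assuming those in/out semantics, a
// helper that drains a whole buffer would look like this analogue; FakeServerLoop is a
// hypothetical stand-in for MicroTVMRpcServerLoop.
#include <cstddef>
#include <cstdint>

int FakeServerLoop(uint8_t** cursor, size_t* remaining) {
  size_t consumed = *remaining < 8 ? *remaining : 8;  // consume up to 8 bytes per call
  *cursor += consumed;
  *remaining -= consumed;
  return 0;  // kTvmErrorNoError
}

int DrainBuffer(uint8_t* data, size_t size) {
  uint8_t* cursor = data;
  size_t remaining = size;
  while (remaining > 0) {
    int err = FakeServerLoop(&cursor, &remaining);
    if (err != 0) return err;  // stop on the first server error
  }
  return 0;
}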
diff --git a/src/runtime/crt/utvm_rpc_common/session.cc b/src/runtime/crt/microtvm_rpc_common/session.cc similarity index 97% rename from src/runtime/crt/utvm_rpc_common/session.cc rename to src/runtime/crt/microtvm_rpc_common/session.cc index e1e338e42825..3570f6260cae 100644 --- a/src/runtime/crt/utvm_rpc_common/session.cc +++ b/src/runtime/crt/microtvm_rpc_common/session.cc @@ -31,7 +31,7 @@ namespace tvm { namespace runtime { namespace micro_rpc { -struct utvm_session_start_payload_t { +struct microtvm_session_start_payload_t { uint8_t version; }; @@ -85,7 +85,7 @@ tvm_crt_error_t Session::StartSession() { RegenerateNonce(); SetSessionId(local_nonce_, 0); - utvm_session_start_payload_t payload = {Session::kVersion}; + microtvm_session_start_payload_t payload = {Session::kVersion}; tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionInit, reinterpret_cast(&payload), sizeof(payload)); if (to_return == 0) { @@ -182,7 +182,7 @@ void Session::ClearReceiveBuffer() { void Session::SendSessionStartReply(const SessionHeader& header) { RegenerateNonce(); SetSessionId(InitiatorNonce(header.session_id), local_nonce_); - utvm_session_start_payload_t payload = {Session::kVersion}; + microtvm_session_start_payload_t payload = {Session::kVersion}; tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionReply, reinterpret_cast(&payload), sizeof(payload)); state_ = State::kSessionEstablished; @@ -195,7 +195,7 @@ void Session::ProcessStartSessionInit(const SessionHeader& header) { return; } - utvm_session_start_payload_t payload; + microtvm_session_start_payload_t payload; int bytes_read = receive_buffer_->Read(reinterpret_cast(&payload), sizeof(payload)); if (bytes_read != sizeof(payload)) { return; @@ -235,7 +235,7 @@ void Session::ProcessStartSessionReply(const SessionHeader& header) { return; } - utvm_session_start_payload_t payload; + microtvm_session_start_payload_t payload; int bytes_read = receive_buffer_->Read(reinterpret_cast(&payload), sizeof(payload)); if (bytes_read != sizeof(payload)) { return; diff --git a/src/runtime/crt/utvm_rpc_common/write_stream.cc b/src/runtime/crt/microtvm_rpc_common/write_stream.cc similarity index 100% rename from src/runtime/crt/utvm_rpc_common/write_stream.cc rename to src/runtime/crt/microtvm_rpc_common/write_stream.cc diff --git a/src/runtime/crt/utvm_rpc_server/rpc_server.cc b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc similarity index 93% rename from src/runtime/crt/utvm_rpc_server/rpc_server.cc rename to src/runtime/crt/microtvm_rpc_server/rpc_server.cc index 1736f98dad12..36077216b19b 100644 --- a/src/runtime/crt/utvm_rpc_server/rpc_server.cc +++ b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc @@ -18,7 +18,7 @@ */ /*! - * \file utvm_rpc_server.cc + * \file rpc_server.cc * \brief MicroTVM RPC Server */ @@ -35,13 +35,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include "../../minrpc/minrpc_server.h" #include "crt_config.h" @@ -87,7 +87,7 @@ class MicroIOHandler { namespace { // Stored as globals so that they can be used to report initialization errors. 
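// For illustration only, not part of the change: the server reports bytes through the
// user-supplied write callback plus an opaque context pointer, with the same shape as
// MicroTVMWriteFunc in host/main.cc above. A standalone callback that captures output
// into a byte vector instead of writing to a file descriptor:
#include <cstddef>
#include <cstdint>
#include <sys/types.h>  // ssize_t (POSIX)
#include <vector>

ssize_t CaptureWriteFunc(void* context, const uint8_t* data, size_t num_bytes) {
  auto* sink = static_cast<std::vector<uint8_t>*>(context);
  sink->insert(sink->end(), data, data + num_bytes);
  return static_cast<ssize_t>(num_bytes);  // report everything as written
}
// A std::vector<uint8_t> would then be passed as the context argument when the server
// is initialized, mirroring MicroTVMRpcServerInit(&MicroTVMWriteFunc, nullptr) above.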
-utvm_rpc_channel_write_t g_write_func = nullptr; +microtvm_rpc_channel_write_t g_write_func = nullptr; void* g_write_func_ctx = nullptr; } // namespace @@ -109,7 +109,7 @@ class SerialWriteStream : public WriteStream { class MicroRPCServer { public: MicroRPCServer(uint8_t* receive_storage, size_t receive_storage_size_bytes, - utvm_rpc_channel_write_t write_func, void* write_func_ctx) + microtvm_rpc_channel_write_t write_func, void* write_func_ctx) : receive_buffer_{receive_storage, receive_storage_size_bytes}, framer_{&send_stream_}, session_{&framer_, &receive_buffer_, &HandleCompleteMessageCb, this}, @@ -197,9 +197,10 @@ void* operator new[](size_t count, void* ptr) noexcept { return ptr; } extern "C" { -static utvm_rpc_server_t g_rpc_server = nullptr; +static microtvm_rpc_server_t g_rpc_server = nullptr; -utvm_rpc_server_t UTvmRpcServerInit(utvm_rpc_channel_write_t write_func, void* write_func_ctx) { +microtvm_rpc_server_t MicroTVMRpcServerInit(microtvm_rpc_channel_write_t write_func, + void* write_func_ctx) { tvm::runtime::micro_rpc::g_write_func = write_func; tvm::runtime::micro_rpc::g_write_func_ctx = write_func_ctx; @@ -223,7 +224,7 @@ utvm_rpc_server_t UTvmRpcServerInit(utvm_rpc_channel_write_t write_func, void* w } auto rpc_server = new (rpc_server_memory) tvm::runtime::micro_rpc::MicroRPCServer( receive_buffer, TVM_CRT_MAX_PACKET_SIZE_BYTES, write_func, write_func_ctx); - g_rpc_server = static_cast(rpc_server); + g_rpc_server = static_cast(rpc_server); rpc_server->Initialize(); return g_rpc_server; } @@ -258,8 +259,8 @@ void TVMLogf(const char* format, ...) { } } -tvm_crt_error_t UTvmRpcServerLoop(utvm_rpc_server_t server_ptr, uint8_t** new_data, - size_t* new_data_size_bytes) { +tvm_crt_error_t MicroTVMRpcServerLoop(microtvm_rpc_server_t server_ptr, uint8_t** new_data, + size_t* new_data_size_bytes) { tvm::runtime::micro_rpc::MicroRPCServer* server = static_cast(server_ptr); return server->Loop(new_data, new_data_size_bytes); diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 7d2ef0c9367b..47a5999fdce9 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -163,6 +163,7 @@ class MetalWorkspace final : public DeviceAPI { void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + void ReinitializeStreams(); // get the global workspace static MetalWorkspace* Global(); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 1c5666dfc17f..0ef07b189a6b 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -131,6 +130,23 @@ int GetWarpSize(id dev) { } } +void MetalWorkspace::ReinitializeStreams() { + std::vector& threadStreams = MetalThreadEntry::ThreadLocal()->stream; + ICHECK_EQ(default_streams_.size(), threadStreams.size()); + for (size_t i = 0; i < default_streams_.size(); ++i) { + if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i]) + delete threadStreams[i]; + delete default_streams_[i]; + } + default_streams_.resize(devices.size()); + threadStreams.resize(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + Stream* stream = new Stream(devices[i]); + default_streams_[i] = stream; + threadStreams[i] = stream; + } +} + void MetalWorkspace::Init() { if (initialized_) return; std::lock_guard lock(this->mutex); @@ -141,21 +157,16 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } #endif + ReinitializeStreams(); } void MetalWorkspace::SetDevice(Device dev) { @@ -193,11 +204,10 @@ int GetWarpSize(id dev) { }; } -Stream* GetStream(TVMStreamHandle stream, int device_id) { - if (stream != nullptr) - return static_cast(stream); - else - return MetalThreadEntry::ThreadLocal()->stream[device_id]; +Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) return static_cast(stream); + ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr); + return MetalThreadEntry::ThreadLocal()->stream[device_id]; } void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, @@ -206,11 +216,11 @@ int GetWarpSize(id dev) { AUTORELEASEPOOL { this->Init(); Device dev = dev_from; - Stream* s = GetStream(stream, dev.device_id); + if (dev_from.device_type == kDLCPU) dev = dev_to; + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! 
Cannot copy data to current stream"; } - if (dev_from.device_type == kDLCPU) dev = dev_to; id cb = s->GetCommandBuffer(); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -269,19 +279,23 @@ int GetWarpSize(id dev) { } TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* stream = new Stream(devices[dev.device_id]); return static_cast(stream); } void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { ICHECK(stream != nullptr); + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* s = static_cast(stream); + if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s) + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr; delete s; } void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { - Stream* s = GetStream(stream, dev.device_id); + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); [cb commit]; @@ -293,6 +307,8 @@ int GetWarpSize(id dev) { } void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; + ICHECK(stream != nullptr); MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); } @@ -337,6 +353,10 @@ int GetWarpSize(id dev) { *rv = static_cast(ptr); }); +TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { + MetalWorkspace::Global()->ReinitializeStreams(); +}); + } // namespace metal } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/standalone/utvm_graph_executor.cc b/src/runtime/micro/standalone/microtvm_graph_executor.cc similarity index 99% rename from src/runtime/micro/standalone/utvm_graph_executor.cc rename to src/runtime/micro/standalone/microtvm_graph_executor.cc index 920faa134cf5..d91d0ea74ba4 100644 --- a/src/runtime/micro/standalone/utvm_graph_executor.cc +++ b/src/runtime/micro/standalone/microtvm_graph_executor.cc @@ -17,7 +17,7 @@ * under the License. */ -#include "utvm_graph_executor.h" +#include "microtvm_graph_executor.h" #include diff --git a/src/runtime/micro/standalone/utvm_graph_executor.h b/src/runtime/micro/standalone/microtvm_graph_executor.h similarity index 94% rename from src/runtime/micro/standalone/utvm_graph_executor.h rename to src/runtime/micro/standalone/microtvm_graph_executor.h index afede6a7b30a..73aead54aaed 100644 --- a/src/runtime/micro/standalone/utvm_graph_executor.h +++ b/src/runtime/micro/standalone/microtvm_graph_executor.h @@ -17,8 +17,8 @@ * under the License. 
*/ -#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_EXECUTOR_H_ -#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_EXECUTOR_H_ +#ifndef TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_GRAPH_EXECUTOR_H_ +#define TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_GRAPH_EXECUTOR_H_ #include @@ -30,8 +30,8 @@ #include #include +#include "microtvm_runtime_api.h" #include "minimal_vector.h" -#include "utvm_runtime_api.h" namespace tvm { namespace micro { @@ -164,4 +164,4 @@ class MicroGraphExecutor { } // namespace micro } // namespace tvm -#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_GRAPH_EXECUTOR_H_ +#endif // TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_GRAPH_EXECUTOR_H_ diff --git a/src/runtime/micro/standalone/utvm_runtime.cc b/src/runtime/micro/standalone/microtvm_runtime.cc similarity index 74% rename from src/runtime/micro/standalone/utvm_runtime.cc rename to src/runtime/micro/standalone/microtvm_runtime.cc index 585da9300128..a51be1414b68 100644 --- a/src/runtime/micro/standalone/utvm_runtime.cc +++ b/src/runtime/micro/standalone/microtvm_runtime.cc @@ -16,38 +16,38 @@ * specific language governing permissions and limitations * under the License. */ -#include "tvm/runtime/micro/standalone/utvm_runtime.h" +#include "tvm/runtime/micro/standalone/microtvm_runtime.h" #include -#include "utvm_graph_executor.h" +#include "microtvm_graph_executor.h" -void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) { +void* MicroTVMRuntimeCreate(const char* json, size_t json_len, void* module) { return new tvm::micro::MicroGraphExecutor(std::string(json, json + json_len), reinterpret_cast(module)); } -void UTVMRuntimeDestroy(void* handle) { +void MicroTVMRuntimeDestroy(void* handle) { delete reinterpret_cast(handle); } -void UTVMRuntimeSetInput(void* handle, int index, void* tensor) { +void MicroTVMRuntimeSetInput(void* handle, int index, void* tensor) { reinterpret_cast(handle)->SetInput( index, reinterpret_cast(tensor)); } -void UTVMRuntimeRun(void* handle) { +void MicroTVMRuntimeRun(void* handle) { reinterpret_cast(handle)->Run(); } -void UTVMRuntimeGetOutput(void* handle, int index, void* tensor) { +void MicroTVMRuntimeGetOutput(void* handle, int index, void* tensor) { reinterpret_cast(handle)->CopyOutputTo( index, reinterpret_cast(tensor)); } -void* UTVMRuntimeDSOModuleCreate(const char* so, size_t so_len) { +void* MicroTVMRuntimeDSOModuleCreate(const char* so, size_t so_len) { return new tvm::micro::DSOModule(std::string(so, so + so_len)); } -void UTVMRuntimeDSOModuleDestroy(void* module) { +void MicroTVMRuntimeDSOModuleDestroy(void* module) { delete reinterpret_cast(module); } diff --git a/src/runtime/micro/standalone/utvm_runtime_api.cc b/src/runtime/micro/standalone/microtvm_runtime_api.cc similarity index 98% rename from src/runtime/micro/standalone/utvm_runtime_api.cc rename to src/runtime/micro/standalone/microtvm_runtime_api.cc index a6ac420feec2..c266107faafb 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.cc +++ b/src/runtime/micro/standalone/microtvm_runtime_api.cc @@ -17,7 +17,7 @@ * under the License. 
*/ -#include "utvm_runtime_api.h" +#include "microtvm_runtime_api.h" #include diff --git a/src/runtime/micro/standalone/utvm_runtime_api.h b/src/runtime/micro/standalone/microtvm_runtime_api.h similarity index 91% rename from src/runtime/micro/standalone/utvm_runtime_api.h rename to src/runtime/micro/standalone/microtvm_runtime_api.h index b38aa0a47a8c..47d4d80b9c09 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.h +++ b/src/runtime/micro/standalone/microtvm_runtime_api.h @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_ -#define TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_ +#ifndef TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_API_H_ +#define TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_API_H_ #include #include @@ -51,4 +51,4 @@ TVM_MICRO_RUNTIME_API_BACKEND_API const char* TVMGetLastError(void); #undef TVM_MICRO_RUNTIME_API_BACKEND_API -#endif // TVM_RUNTIME_MICRO_STANDALONE_UTVM_RUNTIME_API_H_ +#endif // TVM_RUNTIME_MICRO_STANDALONE_MICROTVM_RUNTIME_API_H_ diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 397f57b36dad..4040d82b33e7 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -64,8 +64,13 @@ class OpenCLWrappedFunc { } // setup arguments. for (cl_uint i = 0; i < arg_size_.size(); ++i) { - auto* arg = static_cast(void_args[i]); - OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg->buffer)); + void* arg = nullptr; + if (args.type_codes[i] == DLDataTypeCode::kDLOpaqueHandle) { + arg = static_cast(void_args[i])->buffer; + } else { + arg = void_args[i]; + } + OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); @@ -193,8 +198,8 @@ void OpenCLModuleNode::Init() { ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " << "source from code generation, but no kernel " << "delimiter was found."; - ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size()) - << "The number of registered kernels does not match number of parsed kernel sources"; + ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) + << "The number of parsed kernel sources does not match the number of kernel functions"; // zero initialize cl_program pointers for each device kernel for (auto& kv : parsed_kernels_) { programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 9bb782b384dd..e83f062795e4 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -204,7 +204,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { using Stream::WriteArray; void MessageStart(uint64_t packet_nbytes) { - // Unused here, implemented for uTVM framing layer. + // Unused here, implemented for microTVM framing layer. } bool Read(RPCCode* code) { @@ -219,7 +219,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } void MessageDone() { - // Unused here, implemented for uTVM framing layer. + // Unused here, implemented for microTVM framing layer. } template diff --git a/src/runtime/thread_map.h b/src/runtime/thread_map.h new file mode 100644 index 000000000000..c3fc7e31e9bd --- /dev/null +++ b/src/runtime/thread_map.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_THREAD_MAP_H_ +#define TVM_RUNTIME_THREAD_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief Container to hold one value per thread + * + * Similar to thread_local, but intended for use as a non-static or + * non-block variable, such as class member variables. All member + * functions are thread-safe to call. If only the current thread's + * value is accessed, no additional synchronization is required. If + * another thread's stored values are accessed, external + * synchronization may be required. + * + * Calls that only require access to already-existing values will not + * block each other. Calls that require constructing a new value will + * block any other calls. + * + * \tparam T The object type to be held. For instantiation of + * ThreadMap and for calls to ThreadMap::Get, only a forward + * declaration is required. For calls to ThreadMap::GetOrMake, a + * full class definition is required. + */ +template +class ThreadMap { + public: + ThreadMap() {} + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get() const { return this->Get(std::this_thread::get_id()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get(std::thread::id id) const { + std::shared_lock lock(mutex_); + auto res = values_.find(id); + if (res == values_.end()) { + return nullptr; + } else { + return res->second.get(); + } + } + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get() { return const_cast(const_cast*>(this)->Get()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get(std::thread::id id) { + return const_cast(const_cast*>(this)->Get(id)); + } + + /*! \brief Return the current thread's stored object, making it if + * necessary. + * + * Since this method can modify the stored map, there is no + * non-const version available. + * + * \tparam Params Types of the stored object's constructor arguments + * + * \return A reference to the stored object + */ + template + T& GetOrMake(Params&&... params) { + return GetOrMake(std::this_thread::get_id(), std::forward(params)...); + } + + /*! 
\brief Return the stored object for a given thread, making it if + * necessary + * + * Since this method can modify the stored map, there is no + * non-const version available. + * + * \tparam Params Types of the stored object's constructor arguments + * + * \param id The thread whose object should be returned. + * + * \param params Arguments to the stored object's constructor. Only + * used if the specified thread does not currently exist in the map. + * + * \return A reference to the stored object + */ + template + T& GetOrMake(std::thread::id id, Params&&... params) { + // Try to get stored value first, which would only require shared + // access. + if (T* output = Get(id)) { + return *output; + } + + // Not in map, need exclusive lock to write + std::unique_lock lock(mutex_); + + // Check again, in case another thread got the unique lock first + // and already constructed the object. + auto res = values_.find(id); + if (res != values_.end()) { + return *res->second; + } + + // No value exists, make one and return it. + std::unique_ptr& new_val = values_[id] = + std::make_unique(std::forward(params)...); + return *new_val; + } + + /*! \brief Clears all values held by the ThreadMap + * + * Calling Clear() invalidates any pointers/references previously + * returned by Get/GetOrMake. + * + */ + void Clear() { + std::unique_lock lock(mutex_); + values_.clear(); + } + + private: + //! \brief Mutex to protect values_ + mutable std::shared_timed_mutex mutex_; + + //! \brief Map containing stored values + std::unordered_map> values_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_THREAD_MAP_H_ diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index bb088c5653f2..c282eb006f92 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -45,7 +45,7 @@ class PooledAllocator final : public Allocator { ~PooledAllocator() { ReleaseAll(); } Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; auto&& it = memory_pool_.find(size); if (it != memory_pool_.end() && !it->second.empty()) { @@ -57,14 +57,22 @@ class PooledAllocator final : public Allocator { Buffer buf; buf.device = device_; buf.size = size; - buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + try { + buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + used_memory_.fetch_add(size, std::memory_order_relaxed); DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; return buf; } void Free(const Buffer& buffer) override { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); if (memory_pool_.find(buffer.size) == memory_pool_.end()) { memory_pool_.emplace(buffer.size, std::vector{}); } @@ -76,7 +84,7 @@ class PooledAllocator final : public Allocator { private: void ReleaseAll() { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); for (auto const& it : memory_pool_) { auto const& pool = it.second; for (auto const& buf : pool) { @@ -92,7 +100,7 @@ class PooledAllocator final : public 
Allocator { size_t page_size_; std::atomic used_memory_; std::unordered_map > memory_pool_; - std::mutex mu_; + std::recursive_mutex mu_; Device device_; }; diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/runtime/vulkan/vulkan_buffer.cc index 7059e7c617f4..ef8215c01738 100644 --- a/src/runtime/vulkan/vulkan_buffer.cc +++ b/src/runtime/vulkan/vulkan_buffer.cc @@ -19,27 +19,125 @@ #include "vulkan_buffer.h" +#include + #include "vulkan_device_api.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { namespace vulkan { -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) { - if (buf && buf->vk_buf) { - if (buf->host_addr != nullptr) { - vkUnmapMemory(buf->device, buf->vk_buf->memory); - } - if (buf->vk_buf->memory != VK_NULL_HANDLE) { - vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr); - } - if (buf->vk_buf->buffer != VK_NULL_HANDLE) { - vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr); +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) { + VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO}; + info.size = nbytes; + // Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to + // specify the queue families. + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + return info; +} + +VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index) + : device_(device) { + // Create a buffer + VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage); + VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer)); + + // Allocate memory + VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; + mem_info.allocationSize = buffer_info.size; + mem_info.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR dedicated_info = { + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR}; + + bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize); + if (use_dedicated_allocation) { + dedicated_info.buffer = buffer; + mem_info.pNext = &dedicated_info; + } + + VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory)); + + // Bind the buffer to the allocated memory + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); +} + +VulkanBuffer::~VulkanBuffer() { + if (buffer) { + vkDestroyBuffer(device_, buffer, nullptr); + } + if (memory) { + vkFreeMemory(device_, memory, nullptr); + } +} + +VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) + : device_(other.device_), buffer(other.buffer), memory(other.memory) { + other.device_ = VK_NULL_HANDLE; + other.buffer = VK_NULL_HANDLE; + other.memory = VK_NULL_HANDLE; +} + +VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) { + std::swap(device_, other.device_); + std::swap(buffer, other.buffer); + std::swap(memory, other.memory); + return *this; +} + +bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes) { + if (device.get_buffer_memory_requirements_2_functions) { + // Which buffer to request information about + VkBufferMemoryRequirementsInfo2KHR req_info2 = { + VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR}; + req_info2.buffer = buffer; + + // What information to request + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + + VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR}; + req2.pNext = 
&dedicated_req; + + device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + device, &req_info2, &req2); + if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) { + *nbytes = req2.memoryRequirements.size; + return true; } - buf->host_addr = nullptr; - delete buf->vk_buf; } + + return false; +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, + VkBufferUsageFlags usage, uint32_t mem_type_index) + : vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) { + VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr)); +} + +VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() { + if (host_addr) { + vkUnmapMemory(vk_buf.device_, vk_buf.memory); + } +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other) + : vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) { + other.host_addr = nullptr; + other.size = 0; +} + +VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) { + std::swap(vk_buf, other.vk_buf); + std::swap(host_addr, other.host_addr); + std::swap(size, other.size); + + return *this; } } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h index 77406ec2b2f8..a3e37431e434 100644 --- a/src/runtime/vulkan/vulkan_buffer.h +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -29,20 +29,120 @@ namespace tvm { namespace runtime { namespace vulkan { -struct VulkanBuffer { +class VulkanDevice; + +class VulkanBuffer { + public: + /* \brief Allocate memory on the device + * + * \param device Which device should have the memory allocation. + * The VulkanDevice given should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Destructor, deallocates the memory and buffer. + ~VulkanBuffer(); + + // Forbid copy assignment/constructor + VulkanBuffer(const VulkanBuffer&) = delete; + VulkanBuffer& operator=(const VulkanBuffer&) = delete; + + // Allow move assignment/constructor + VulkanBuffer(VulkanBuffer&&); + VulkanBuffer& operator=(VulkanBuffer&&); + + private: + /*! \brief Whether this buffer should be allocated using dedicated + * allocation + * + * In typical usage, there will be one VkDeviceMemory that has a + * large number of VkBuffers pointing to it. Currently, the TVM + * Vulkan runtime has a single VkBuffer for each VkDeviceMemory. In + * this case, there can be performance benefits by explicitly + * marking this as a dedicated allocation. The function returns + * true if the device supports the dedicated allocation extension, + * and the buffer either requires or has better performance with a + * dedicated allocation. + * + * \param[out] nbytes If using dedicated allocation, the number of + * bytes required for the allocation. If not using dedicated + * allocation, this value is unchanged. + * + * \returns Whether the allocation should use the dedicated + * allocation extension. 
+ */ + static bool UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes); + + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + /*! \brief Pointer to the device that owns this buffer. + * + * Assumes that the VulkanBuffer will be destructed before the + * VulkanDevice, and this will never be a dangling reference. + * Stores a VkDevice and not a VulkanDevice, because the + * VulkanDevice may be moved to a different location while the + * VulkanBuffer is alive. + */ + VkDevice device_{VK_NULL_HANDLE}; + + //! \brief Handle to the logical buffer on the device VkBuffer buffer{VK_NULL_HANDLE}; + + //! \brief Handle to the physical device memory VkDeviceMemory memory{VK_NULL_HANDLE}; + + friend class VulkanHostVisibleBuffer; }; /*! \brief A struct to represent Vulkan buffers backed by host visible memory */ -struct VulkanHostVisibleBuffer { - // A device where the buffer is allocated - VkDevice device{nullptr}; - // Vulkan buffer and memory - VulkanBuffer* vk_buf{nullptr}; - // The corresponding pointer to the host memory +class VulkanHostVisibleBuffer { + public: + /* \brief Allocate memory on the device, visible to the host + * + * \param device Which GPU device should have the memory allocation. + * The VulkanDevice specified should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Unmap memory and deallocate. + ~VulkanHostVisibleBuffer(); + + // Forbid copy assignment/constructor + VulkanHostVisibleBuffer(const VulkanHostVisibleBuffer&) = delete; + VulkanHostVisibleBuffer& operator=(const VulkanHostVisibleBuffer&) = delete; + + // Allow move assignment/constructor + VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&&); + VulkanHostVisibleBuffer& operator=(VulkanHostVisibleBuffer&&); + + private: + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + VulkanBuffer vk_buf; void* host_addr{nullptr}; - // The size of the buffer in bytes size_t size{0}; }; @@ -54,8 +154,6 @@ VulkanHostVisibleBuffer* GetOrAllocate( std::unordered_map>* buffers_ptr, bool sync_before_realloc = false); -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf); - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index e92b566e0aab..5e4be8209550 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -28,7 +28,6 @@ #include "vulkan_device.h" #include "vulkan_device_api.h" #include "vulkan_instance.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -310,6 +309,14 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ } VulkanDevice::~VulkanDevice() { + // Need to clear anything that uses this device before calling + // vkDestroyDevice. Might be a sign that the VkDevice should be + // held by member variable rather than being owned directly by + // VulkanDevice.
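For context, the per-thread members cleared just below are instances of the new ThreadMap<T> container added in src/runtime/thread_map.h. A minimal usage sketch of that container follows, assuming the relative include path used by vulkan_device.h; the Scratch and PerThreadScratch names are hypothetical and not part of this change:

#include <cstddef>
#include <vector>

#include "../thread_map.h"  // assumed include path, matching vulkan_device.h

struct Scratch {
  explicit Scratch(size_t nbytes) : data(nbytes) {}
  std::vector<char> data;
};

class PerThreadScratch {
 public:
  // Lazily construct the calling thread's Scratch on first use; later calls
  // from the same thread only take ThreadMap's shared (read) lock.
  Scratch& ForCurrentThread(size_t nbytes) { return per_thread_.GetOrMake(nbytes); }

  // Invalidate every thread's Scratch, mirroring the Clear() calls in
  // ~VulkanDevice() just below.
  void Reset() { per_thread_.Clear(); }

 private:
  tvm::runtime::ThreadMap<Scratch> per_thread_;
};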
+ stream_per_thread.Clear(); + staging_buffer_per_thread.Clear(); + uniform_buffer_per_thread.Clear(); + if (device_) { vkDestroyDevice(device_, nullptr); } @@ -491,6 +498,49 @@ void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) { VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_)); } +VulkanStream& VulkanDevice::ThreadLocalStream() { + return const_cast(const_cast(this)->ThreadLocalStream()); +} + +const VulkanStream& VulkanDevice::ThreadLocalStream() const { + return stream_per_thread.GetOrMake(this); +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalStagingBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VulkanStagingBuffer& result = + staging_buffer_per_thread.GetOrMake(*this, min_size, usage, staging_mtype_index); + + if (result.size < min_size) { + result = VulkanStagingBuffer(*this, min_size, usage, staging_mtype_index); + } + + return result; +} + +void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; + auto buffer_info = MakeBufferCreateInfo(min_size, usage); + auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + auto mem_type_index = FindMemoryType(*this, buffer_info, prop); + + VulkanUniformBuffer& result = + uniform_buffer_per_thread.GetOrMake(*this, min_size, usage, mem_type_index); + + if (result.size < min_size) { + result = VulkanUniformBuffer(*this, min_size, usage, mem_type_index); + } +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) { + VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get(); + ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated."; + ICHECK_GE(buffer->size, min_size) + << "Vulkan uniform buffer of size " << min_size << " requested, but only " << buffer->size + << " was previously allocated."; + return *buffer; +} + uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop) { VkBuffer buffer; @@ -512,115 +562,26 @@ uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, return 0; } -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, - VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(device.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - return info; -} - -VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(device, nbytes, usage); - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (device.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - 
device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = info.size; - minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - VulkanHostVisibleBuffer* GetOrAllocate( int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, std::unordered_map>* buffers_ptr, bool sync_before_realloc) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + auto& buffers = *buffers_ptr; - if (!buffers[device_id]) { - buffers[device_id] = std::make_unique(); - } - auto& buf = *(buffers[device_id]); - if (buf.device != nullptr && buf.size < size) { - // free previous buffer - if (sync_before_realloc) { - // For the deferred execution mode, we need to make sure that old tasks that use - // the older, smaller buffer get finished - // Synchronization on staging buffers is done after host to device memory copy - // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization - // points - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); - } - DeleteHostVisibleBuffer(&buf); + bool needs_alloc = !buffers[device_id] || (buffers[device_id]->size < size); + bool is_realloc = buffers[device_id] && (buffers[device_id]->size < size); + if (is_realloc && sync_before_realloc) { + device.ThreadLocalStream().Synchronize(); } - const auto& vulkan_device = VulkanDeviceAPI::Global()->device(device_id); - - if (buf.device == nullptr) { - buf.device = vulkan_device; - } - if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vulkan_device, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vulkan_device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); - buf.size = size; + if (needs_alloc) { + auto new_buffer = + std::make_unique(device, size, usage, mem_type_index); + buffers[device_id] = std::move(new_buffer); } - return &buf; + return buffers[device_id].get(); } } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index b55eb8a3d9e0..045628bc9092 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -21,14 +21,19 @@ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ #include -#include #include +#include +#include #include +#include +#include #include +#include "../thread_map.h" #include "vulkan/vulkan_core.h" #include "vulkan_buffer.h" +#include "vulkan_stream.h" namespace tvm { namespace runtime { @@ -156,6 +161,43 @@ class VulkanDevice { */ bool HasExtension(const char* query) 
const; + //! \brief Return the VulkanStream for the current CPU thread + VulkanStream& ThreadLocalStream(); + + //! \brief Return the VulkanStream for the current CPU thread + const VulkanStream& ThreadLocalStream() const; + + /*! \brief Return the staging buffer for the current CPU thread + * + * This function may re-allocate the staging buffer depending on the + * size of the previously allocated buffer. + * + * \param min_size The size in bytes of the staging buffer to be + * returned. The buffer may be larger than requested, depending on + * previous use. + */ + VulkanStagingBuffer& ThreadLocalStagingBuffer(size_t min_size); + + /*! \brief Allocate the uniform buffer for the current CPU thread + * + * \param min_size The minimum size in bytes of the uniform buffer + * to be allocated. If a larger uniform buffer has already been + * allocated, no allocation is performed. + */ + void AllocateThreadLocalUniformBuffer(size_t min_size); + + /*! \brief Return the uniform buffer for the current CPU thread + * + * Assumes that AllocateThreadLocalUniformBuffer has previously been + * called, with a min_size greater than or equal to the min_size of + * the current call. If this is not the case, will throw an + * exception. + * + * \param min_size The minimum size in bytes of the uniform buffer to be + * returned. + */ + VulkanUniformBuffer& ThreadLocalUniformBuffer(size_t min_size); + // Cached device properties, queried through Vulkan API. VulkanDeviceProperties device_properties{}; @@ -183,8 +225,24 @@ class VulkanDevice { */ void do_swap(VulkanDevice&& other); + /*! \brief Returns a queue family capable of running Vulkan compute + * operations + */ uint32_t SelectComputeQueueFamily() const; + + /*! \brief Returns the extensions to be enabled. + * + * All char* in the returned vector point to static memory + * allocations, and do not require cleanup. + */ std::vector<const char*> SelectEnabledExtensions() const; + + /*! \brief Initialize the VkDevice + * + * Called during VulkanDevice construction. Assumes that + * queue_family_index, device_properties, and enabled_extensions + * have been set. + */ void CreateVkDevice(const VulkanInstance& instance); //! \brief Handle to the Vulkan API physical device @@ -207,19 +265,30 @@ class VulkanDevice { /*! \brief Handle to Vulkan API VkQueue. * * Work can be executed by being submitted to this queue using - * VulkanDevice::SubmitQueue. + * VulkanDevice::QueueSubmit. */ VkQueue queue{nullptr}; + + /*! \brief The VulkanStream for each CPU thread. + * + * To mimic the semantics of cudaSetDevice and cuLaunchKernel, each + * CPU thread must have a separate stream of execution. The + * ThreadMap is declared mutable so that the streams can be lazily + * generated. + */ + mutable ThreadMap<VulkanStream> stream_per_thread; + + //! \brief The VulkanStagingBuffer for each CPU thread. + ThreadMap<VulkanStagingBuffer> staging_buffer_per_thread; + + //! \brief The VulkanUniformBuffer for each CPU thread.
+ ThreadMap uniform_buffer_per_thread; }; uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop); -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, - VkBufferUsageFlags usage); - -VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index); +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index bc25f25e7e12..1fede98f7211 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -26,7 +26,6 @@ #include #include "vulkan_common.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -55,7 +54,20 @@ VulkanDeviceAPI::VulkanDeviceAPI() { VulkanDeviceAPI::~VulkanDeviceAPI() {} -void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; } +void VulkanDeviceAPI::SetDevice(Device dev) { + ICHECK_EQ(dev.device_type, kDLVulkan) + << "Active vulkan device cannot be set to non-vulkan device" << dev; + + ICHECK_LE(dev.device_id, static_cast(devices_.size())) + << "Attempted to set active vulkan device to device_id==" << dev.device_id << ", but only " + << devices_.size() << " devices present"; + + active_device_id_per_thread.GetOrMake(0) = dev.device_id; +} + +int VulkanDeviceAPI::GetActiveDeviceID() { return active_device_id_per_thread.GetOrMake(0); } + +VulkanDevice& VulkanDeviceAPI::GetActiveDevice() { return device(GetActiveDeviceID()); } void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); @@ -225,7 +237,7 @@ void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignmen const auto& device = this->device(dev.device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(device, nbytes, usage, device.compute_mtype_index); + return new VulkanBuffer(device, nbytes, usage, device.compute_mtype_index); } void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -233,19 +245,20 @@ void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { // finish all the vulkan commands that reference the buffer. 
StreamSync(dev, nullptr); - const auto& device = this->device(dev.device_id); auto* pbuf = static_cast(ptr); - vkDestroyBuffer(device, pbuf->buffer, nullptr); - vkFreeMemory(device, pbuf->memory, nullptr); delete pbuf; } void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(dev, size); + auto& pool = pool_per_thread.GetOrMake(kDLVulkan, this); + return pool.AllocWorkspace(dev, size); } void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { - VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(dev, data); + auto* pool = pool_per_thread.Get(); + ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " + << "that has never allocated a workspace"; + pool->FreeWorkspace(dev, data); } TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } @@ -263,7 +276,7 @@ void VulkanDeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { ICHECK_EQ(stream, static_cast(nullptr)); - VulkanThreadEntry::ThreadLocal()->Stream(dev.device_id)->Synchronize(); + device(dev.device_id).ThreadLocalStream().Synchronize(); } void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { @@ -282,96 +295,94 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([=](VulkanStreamState* state) { - // 1: copy - const auto* from_buf = static_cast(from); - auto* to_buf = static_cast(to); - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); - // 2: barrier(transfer-> compute|transfer) - ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Vulkan disallow cross device copy."; - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier( - state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, - &barrier_info, 0, nullptr, 0, nullptr); - }); + ICHECK_EQ(dev_from.device_id, dev_to.device_id) + << "The Vulkan runtime does not support deviceA to deviceB copies. 
" + << "This should be changed to a deviceA to CPU copy, followed by a CPU to deviceB copy"; + + device(dev_from.device_id).ThreadLocalStream().Launch([=](VulkanStreamState* state) { + // 1: copy + const auto* from_buf = static_cast(from); + auto* to_buf = static_cast(to); + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); + // 2: barrier(transfer-> compute|transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }); } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { const auto* from_buf = static_cast(from); - const auto& device = this->device(dev_from.device_id); - auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([&](VulkanStreamState* state) { - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = 0; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1, - ©_info); - }); - VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); + auto& device = this->device(dev_from.device_id); + auto& stream = device.ThreadLocalStream(); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + stream.Launch([&](VulkanStreamState* state) { + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = 0; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, staging_buffer.vk_buf.buffer, 1, + ©_info); + }); + stream.Synchronize(); if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; + mrange.memory = staging_buffer.vk_buf.memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); } - memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); + memcpy(static_cast(to) + to_offset, static_cast(staging_buffer.host_addr), size); } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& device = this->device(dev_to.device_id); + auto& device = this->device(dev_to.device_id); + auto& stream = device.ThreadLocalStream(); const auto* to_buf = static_cast(to); - VulkanStagingBuffer* temp = - VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); - memcpy(temp->host_addr, static_cast(from) + from_offset, size); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + memcpy(staging_buffer.host_addr, static_cast(from) + from_offset, size); // host side flush if access is not coherent. 
// so writes from CPU is visible to GPU if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; + mrange.memory = staging_buffer.vk_buf.memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); } - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_to.device_id) - ->Launch([&](VulkanStreamState* state) { - // 0: barrier(host->transfer) - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = 0; - barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, - nullptr); - // 1: copy - VkBufferCopy copy_info; - copy_info.srcOffset = 0; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1, ©_info); - }); + stream.Launch([&](VulkanStreamState* state) { + // 0: barrier(host->transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = 0; + barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, + nullptr); + // 1: copy + VkBufferCopy copy_info; + copy_info.srcOffset = 0; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1, + ©_info); + }); // TODO(tulloch): should we instead make the staging buffer a property of the // Stream? This would allow us to elide synchronizations here. - VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize(); + stream.Synchronize(); } else { LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" << ", from=" << from_dev_type << ", to=" << to_dev_type; @@ -384,6 +395,10 @@ const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { return devices_[device_id]; } +VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { + return const_cast(const_cast(this)->device(device_id)); +} + TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index cf5652a3d9c4..b8be3eb43c79 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -21,14 +21,16 @@ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ #include +#include #include #include +#include "../thread_map.h" +#include "../workspace_pool.h" #include "vulkan/vulkan_core.h" #include "vulkan_device.h" #include "vulkan_instance.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -69,6 +71,22 @@ class VulkanDeviceAPI final : public DeviceAPI { // End of required methods for the DeviceAPI interface public: + /*! \brief Return the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + VulkanDevice& GetActiveDevice(); + + /*! 
\brief Return the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + int GetActiveDeviceID(); + /*! \brief Return the VulkanDevice associated with a specific device_id * * These are constructed during VulkanDeviceAPI initialization, so @@ -76,6 +94,13 @@ class VulkanDeviceAPI final : public DeviceAPI { */ const VulkanDevice& device(size_t device_id) const; + /*! \brief Return the VulkanDevice associated with a specific device_id + * + * These are constructed during VulkanDeviceAPI initialization, so + * this function returns immediately. + */ + VulkanDevice& device(size_t device_id); + /*! \brief Returns a property to be stored in a target. * * Returns the results of feature/property queries done during the @@ -86,9 +111,33 @@ class VulkanDeviceAPI final : public DeviceAPI { private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); + /*! \brief The Vulkan API instance owned by the VulkanDeviceAPI + * + * Holds and manages VkInstance. + */ VulkanInstance instance_; - // The physical devices, have 1 to 1 mapping to devices + + /*! \brief Handles to the Vulkan devices + * + * The physical devices. These are constructed after the instance_, + * and must be destructed before the instance_. + */ std::vector devices_; + + /*! \brief One pool of device memory for each CPU thread. + * + * These allocate memory based on the devices stored in devices_. + * The memory pools must be destructed before devices_. + */ + ThreadMap pool_per_thread; + + /*! \brief The index of the active device for each CPU thread. + * + * To mimic the semantics of cudaSetDevice, each CPU thread can set + * the device on which functions should run. If unset, the active + * device defaults to device_id == 0. + */ + ThreadMap active_device_id_per_thread; }; } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index 9784ee78503d..3eff112a6eea 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -19,6 +19,8 @@ #include "vulkan_stream.h" +#include "vulkan_device.h" + namespace tvm { namespace runtime { namespace vulkan { diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index ff02be4c5c35..fb4e447c15e1 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -26,12 +26,13 @@ #include #include "vulkan_common.h" -#include "vulkan_device.h" namespace tvm { namespace runtime { namespace vulkan { +class VulkanDevice; + class VulkanStreamState { public: VkCommandBuffer cmd_buffer_; diff --git a/src/runtime/vulkan/vulkan_thread_entry.cc b/src/runtime/vulkan/vulkan_thread_entry.cc deleted file mode 100644 index 1e2815f31146..000000000000 --- a/src/runtime/vulkan/vulkan_thread_entry.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "vulkan_thread_entry.h" - -#include "vulkan_buffer.h" -#include "vulkan_device_api.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -VulkanThreadEntry::~VulkanThreadEntry() { - // Because the thread entry refers to Device API - // The command buffer always will be destroyed before - // the instance and device get destroyed. - // The destruction need to be manually called - // to ensure the destruction order. - - pool.reset(); - streams_.clear(); - for (const auto& kv : staging_buffers_) { - DeleteHostVisibleBuffer(kv.second.get()); - } -} - -VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } - -void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); - auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(device, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(device, info, prop); - GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, - &uniform_buffers_, true); -} - -VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) { - auto& buf = uniform_buffers_[device_id]; - ICHECK(buf); - ICHECK_GE(buf->size, size); - return buf.get(); -} - -VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, device.staging_mtype_index, &staging_buffers_); -} - -VulkanThreadEntry::VulkanThreadEntry() - : pool(std::make_unique(static_cast(kDLVulkan), - VulkanDeviceAPI::Global())) { - device.device_id = 0; - device.device_type = static_cast(kDLVulkan); -} - -VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { - if (!streams_[device_id]) { - streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->device(device_id))); - } - return streams_[device_id].get(); -} - -} // namespace vulkan -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_thread_entry.h b/src/runtime/vulkan/vulkan_thread_entry.h deleted file mode 100644 index cea5494823fd..000000000000 --- a/src/runtime/vulkan/vulkan_thread_entry.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ - -#include - -#include -#include - -#include "../workspace_pool.h" -#include "vulkan_buffer.h" -#include "vulkan_stream.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -/*! \brief Contains all per-CPU-thread resources. - */ -class VulkanThreadEntry { - public: - VulkanThreadEntry(); - static VulkanThreadEntry* ThreadLocal(); - - ~VulkanThreadEntry(); - - Device device; - std::unique_ptr pool; - VulkanStream* Stream(size_t device_id); - VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); - void AllocateUniformBuffer(int device_id, size_t size); - VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size); - - private: - //! Map from device to the VulkanStream for it - std::unordered_map> streams_; - //! Map from device to the StagingBuffer for it - std::unordered_map> staging_buffers_; - //! Map from device to the UniformBuffer associated with it - std::unordered_map> uniform_buffers_; -}; - -typedef dmlc::ThreadLocalStore VulkanThreadStore; - -} // namespace vulkan -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 86c3ffe23f7d..103b2aa7692c 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -25,7 +25,6 @@ #include "../file_utils.h" #include "vulkan_device_api.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -45,9 +44,8 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { - int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; - ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& device = VulkanDeviceAPI::Global()->device(device_id); + int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID(); + auto& device = VulkanDeviceAPI::Global()->device(device_id); if (!scache_[device_id]) { scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } @@ -65,17 +63,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, } const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - CHECK(ubo->host_addr) << "The UBO host buffer is not allocated"; + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); VkDescriptorBufferInfo binfo; - binfo.buffer = ubo->vk_buf->buffer; + binfo.buffer = ubo.vk_buf.buffer; binfo.offset = 0; binfo.range = VK_WHOLE_SIZE; descriptor_buffers.push_back(binfo); } if (device.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. 
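The UBO path above relies on a simple contract: AllocateThreadLocalUniformBuffer must already have been called on this CPU thread (it happens in VulkanModuleNode::GetPipeline further down) with at least the size later requested from ThreadLocalUniformBuffer, which ICHECKs both conditions. A hedged sketch of that calling pattern, using only the functions introduced in this patch; FillUniformBuffer is a hypothetical helper, not part of the change:

#include <cstring>

void FillUniformBuffer(tvm::runtime::vulkan::VulkanDevice& device, const void* pack_args,
                       size_t nbytes_scalars) {
  // Typically done once, when the pipeline is created; only reallocates if the
  // existing per-thread buffer is smaller than nbytes_scalars.
  device.AllocateThreadLocalUniformBuffer(nbytes_scalars);

  // At launch time, on the same CPU thread: checks that a buffer of at least
  // nbytes_scalars exists, then returns a reference to it.
  auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
  memcpy(ubo.host_addr, pack_args, nbytes_scalars);
}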
- VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { + device.ThreadLocalStream().Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( @@ -83,8 +80,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, descriptor_buffers.data()); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args, nbytes_scalars); + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args, nbytes_scalars); } else if (num_pack_args_ > 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), @@ -133,14 +130,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, nullptr); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars); + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args_storage.data(), nbytes_scalars); } else if (num_pack_args_ > 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, pack_args_storage.size() * sizeof(ArgUnion64), @@ -164,8 +163,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, for (size_t i = 0; i < descriptor_buffers.size(); ++i) { deferred_token.buffers_[i] = descriptor_buffers[i].buffer; } - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred( - deferred_initializer, deferred_kernel, deferred_token); + device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token); } VulkanModuleNode::~VulkanModuleNode() { @@ -206,7 +204,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); + auto& device = VulkanDeviceAPI::Global()->device(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; if (cp) { @@ -286,7 +284,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, if (pe->use_ubo) { // Use UBO instead of push constants push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); - VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars); + device.AllocateThreadLocalUniformBuffer(nbytes_scalars); } { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index d5140677d45a..48ccefafe3c4 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -62,7 +62,7 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, md_builder_.reset(new 
llvm::MDBuilder(*ctx_)); // types t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(); + t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); t_int_ = llvm::Type::getInt32Ty(*ctx_); t_char_ = llvm::Type::getInt8Ty(*ctx_); t_int8_ = llvm::Type::getInt8Ty(*ctx_); @@ -191,20 +191,10 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... - std::vector param_types; - // args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // num_args - param_types.push_back(t_int_); - // ret_args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // ret_tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // resource_handle - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); + llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace()); + // args, tcodes, num_args, ret_value, ret_tcode, resource_handle + std::vector param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_}; llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false); llvm::Function* function = @@ -215,41 +205,29 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - std::vector zero_index_list{llvm::ConstantInt::get(t_int32_, 0)}; - std::vector zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0), - llvm::ConstantInt::get(t_int32_, 0)}; - auto args_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[0], + + auto getArg = [function](int i) -> llvm::Argument* { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return &function->arg_begin()[i]; #else - &(*(function->arg_begin())), + return &*std::next(function->arg_begin(), i); #endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)); - llvm::Value* sid = builder_->CreateBitCast( - builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()), - builder_->CreateInBoundsGEP(args_array, zero_index_list)), - t_int64_); + }; + + llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); + llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + + auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); + auto ret_value = + builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); - auto ret_types_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[4], -#else - &(*(std::next(function->arg_begin(), 4))), -#endif - llvm::ArrayType::get(t_int_, 1)->getPointerTo()); - auto retval_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[3], -#else - &(*std::next(function->arg_begin(), 3)), -#endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo()); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); builder_->SetInsertPoint(default_block); - 
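The refactored LinkParameters still emits the same lookup function, just with less boilerplate around argument access. Roughly, the emitted LLVM IR corresponds to the following C++ sketch; the symbol name, parameter data and storage id are placeholders, and the signature follows the args/tcodes/num_args/ret_value/ret_tcode/resource_handle layout noted in the comment above:

    #include <cstdint>
    #include <tvm/runtime/c_runtime_api.h>  // kTVMOpaqueHandle, kTVMNullptr

    // Placeholder for a parameter blob linked into the module's data section.
    static const float __tvm_param__weight[64] = {0};

    extern "C" int lookup_linked_param(void** args, int* arg_tcodes, int num_args,
                                       void** ret_value, int* ret_tcode, void* resource_handle) {
      int64_t sid = *reinterpret_cast<int64_t*>(args[0]);  // storage id requested by the runtime
      switch (sid) {
        case 1:  // one case per linked parameter
          *ret_value = const_cast<float*>(__tvm_param__weight);
          *ret_tcode = kTVMOpaqueHandle;
          return 0;
        default:  // unknown storage id: report "no linked parameter"
          *ret_tcode = kTVMNullptr;
          return 0;
      }
    }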
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode); builder_->CreateRet(ConstInt32(kTvmErrorNoError)); // Add data to the global section. @@ -258,16 +236,20 @@ void CodeGenLLVM::LinkParameters(const Map params) { std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + auto dtype = tvm::runtime::DataType(kv.second->param->dtype); + size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); +#if TVM_LLVM_VERSION >= 100 + param_symbol->setAlignment(llvm::Align(align)); +#else + param_symbol->setAlignment(align); +#endif llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); - builder_->CreateStore( - builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())), - builder_->CreateInBoundsGEP(retval_array, zero_array_index_list)); - builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); builder_->CreateRet(ConstInt32(0)); } } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index d8f0f8e90238..f8412b51edcf 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -128,6 +128,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext auto* sizeptr = extent.as(); ICHECK(sizeptr) << "SPIRV only allows constant thread group size " << " get " << extent; + ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; ICHECK_LT(ts.dim_index, 3); workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); } else { @@ -516,9 +517,13 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Must get init label after making value(to make sure they are correct) spirv::Label init_label = builder_->CurrentLabel(); spirv::Label head_label = builder_->NewLabel(); + builder_->SetName(head_label, "for_loop_head"); spirv::Label body_label = builder_->NewLabel(); + builder_->SetName(body_label, "for_loop_body"); spirv::Label continue_label = builder_->NewLabel(); + builder_->SetName(continue_label, "for_loop_continue"); spirv::Label merge_label = builder_->NewLabel(); + builder_->SetName(merge_label, "for_loop_merge"); builder_->MakeInst(spv::OpBranch, head_label); // Loop head @@ -643,9 +648,10 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { + // Will throw error if rebinding same local variable to a different extent. 
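The comment directly above relies on tvm::arith::Analyzer's binding rules: binding a variable again with an identical range is accepted, while a conflicting range fails a check. A small hypothetical illustration (variable name and extents are made up):

    #include <tvm/arith/analyzer.h>
    #include <tvm/tir/var.h>

    void RebindIllustration() {
      tvm::arith::Analyzer analyzer;
      tvm::tir::Var tx("threadIdx.x", tvm::DataType::Int(32));
      analyzer.Bind(tx, tvm::Range::FromMinExtent(0, 64));
      analyzer.Bind(tx, tvm::Range::FromMinExtent(0, 64));    // same extent again: accepted
      // analyzer.Bind(tx, tvm::Range::FromMinExtent(0, 32)); // different extent: check failure
    }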
+ analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); - analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } } else if (op->attr_key == tir::attr::storage_scope) { diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 128d60e7725a..9696043a244d 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -93,6 +93,7 @@ std::vector IRBuilder::Finalize() { data.insert(data.end(), decorate_.begin(), decorate_.end()); data.insert(data.end(), global_.begin(), global_.end()); data.insert(data.end(), func_header_.begin(), func_header_.end()); + data.insert(data.end(), function_scope_vars_.begin(), function_scope_vars_.end()); data.insert(data.end(), function_.begin(), function_.end()); return data; } @@ -309,11 +310,8 @@ Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { ICHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); - if (workgroup_id_.id != 0) { - ib_.Add(workgroup_id_); - } - if (local_id_.id != 0) { - ib_.Add(local_id_); + for (auto& it : built_in_tbl_) { + ib_.Add(it.second); } ib_.Commit(&entry_); } @@ -350,34 +348,88 @@ Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, } Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { - if (workgroup_id_.id == 0) { - SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); - SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); - workgroup_id_ = NewValue(ptr_type, kVectorPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput) - .Commit(&global_); - this->Decorate(spv::OpDecorate, workgroup_id_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); - } - SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, workgroup_id_, - IntImm(t_int32_, static_cast(dim_index))); - return this->MakeValue(spv::OpLoad, t_int32_, ptr); + std::string name = "blockIdx." + std::string(1, 'x' + dim_index); + return GetBuiltInValue(spv::BuiltInWorkgroupId, dim_index, name); } Value IRBuilder::GetLocalID(uint32_t dim_index) { - if (local_id_.id == 0) { - SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); - SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); - local_id_ = NewValue(ptr_type, kVectorPtr); - ib_.Begin(spv::OpVariable).AddSeq(ptr_type, local_id_, spv::StorageClassInput).Commit(&global_); - this->Decorate(spv::OpDecorate, local_id_, spv::DecorationBuiltIn, - spv::BuiltInLocalInvocationId); - } - SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, local_id_, - UIntImm(t_int32_, static_cast(dim_index))); - return this->MakeValue(spv::OpLoad, t_int32_, ptr); + std::string name = "threadIdx." 
+ std::string(1, 'x' + dim_index); + return GetBuiltInValue(spv::BuiltInLocalInvocationId, dim_index, name); +} + +Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const std::string& name) { + // Returned cached value if it exists + { + auto it = built_in_values_tbl_.find({built_in, index}); + if (it != built_in_values_tbl_.end()) { + return it->second; + } + } + + DataType data_type; + DataType global_arr_type; + switch (built_in) { + case spv::BuiltInLocalInvocationId: + case spv::BuiltInWorkgroupId: + data_type = DataType::Int(32); + global_arr_type = data_type.with_lanes(3); + break; + + default: + LOG(FATAL) << "No data type defined for SPIR-V Built-In " << built_in; + } + + // Look up the decorated array value at global scope. If it doesn't + // exist already, declare it. + Value global_array; + { + auto it = built_in_tbl_.find(built_in); + if (it != built_in_tbl_.end()) { + global_array = it->second; + } else { + SType ptr_arr_type = this->GetPointerType(GetSType(global_arr_type), spv::StorageClassInput); + global_array = NewValue(ptr_arr_type, kVectorPtr); + + ib_.Begin(spv::OpVariable) + .AddSeq(ptr_arr_type, global_array, spv::StorageClassInput) + .Commit(&global_); + this->Decorate(spv::OpDecorate, global_array, spv::DecorationBuiltIn, built_in); + + switch (built_in) { + case spv::BuiltInLocalInvocationId: + SetName(global_array, "BuiltInLocalInvocationId"); + break; + case spv::BuiltInWorkgroupId: + SetName(global_array, "BuiltInWorkgroupId"); + break; + + default: + break; + } + + built_in_tbl_[built_in] = global_array; + } + } + + // Declare the dereferenced value + SType data_stype = GetSType(data_type); + SType ptr_type = this->GetPointerType(data_stype, spv::StorageClassInput); + Value global_const_index = UIntImm(t_int32_, static_cast(index)); + + Value ptr = NewValue(ptr_type, kNormal); + ib_.Begin(spv::OpAccessChain) + .AddSeq(ptr_type, ptr, global_array, global_const_index) + .Commit(&function_scope_vars_); + + Value output = NewValue(data_stype, kNormal); + ib_.Begin(spv::OpLoad).AddSeq(data_stype, output, ptr).Commit(&function_scope_vars_); + if (name.size()) { + SetName(output, name); + } + + // Store to cache and return + built_in_values_tbl_[{built_in, index}] = output; + return output; } Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 959ed294640e..3e19b98100c0 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -334,6 +334,18 @@ class IRBuilder { void Debug(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&debug_); } + + /*! + * \brief Set the name of a value or label + * \param obj The object to be named + * \param name The name of the object + * \tparams Obj The type of the object being named. Typically a Label or Value. + */ + template + void SetName(Obj&& obj, const std::string& name) { + Debug(spv::OpName, std::forward(obj), name); + } + /*! * \brief Add Execution mode to a function. * \param func The function value @@ -362,7 +374,7 @@ class IRBuilder { */ template void DeclareGlobal(spv::Op op, Args&&... args) { - ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); + ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&global_); } /*! * \brief Make a new instruction and append it to end of function segment. @@ -583,6 +595,20 @@ class IRBuilder { return val; } + /*! 
\brief Get a built-in value provided by SPIR-V + * + * \param built_in The SPIR-V built-in array to access. For + * example, spv::BuiltInLocalInvocationId to access the thread + * id. + * + * \param index The index of the built-in array to access. + * + * \param name The name of the value being accessed. For + * example, "threadIdx.x". This is for debug purposes, and is + * used to tag the variable with OpName. + */ + Value GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const std::string& name = ""); + /*! * \brief The common function to declare push constants or uniform buffer * \param value_types The values in the push constants or uniform buffer @@ -630,8 +656,32 @@ class IRBuilder { SType t_bool_, t_int32_, t_uint32_, t_fp32_, t_void_, t_void_func_; /*! \brief quick cache for const one i32 */ Value const_i32_zero_; - /*! \brief cache value for workgroup_id, local_id */ - Value workgroup_id_, local_id_; + + /*! \brief The cached values for built-in arrays + * + * Maps from a tuple of spv::BuiltIn enum to the Value containing + * that built-in array. For example, + * ``built_in_tbl_[spv::BuiltInLocalInvocationId]`` is the array + * of invocation ids, equivalent to an array of ``threadIdx.x``, + * ``threadIdx.y``, and ``threadIdx.z`` in CUDA. + * + * These are declared in the global section of the shader. + */ + std::unordered_map built_in_tbl_; + + /*! \brief The cached values for built-in values + * + * Maps from a tuple of (spv::BuiltIn enum, index) to the value + * stored at that index of the built-in array. For example, + * ``built_in_tbl_[{spv::BuiltInLocalInvocationId, 0}]`` is the + * first index of the invocation id, equivalent to + * ``threadIdx.x`` in CUDA. + * + * These are declared in the first block of the function, in the + * ``function_scope_vars_`` section. + */ + std::map, Value> built_in_values_tbl_; + /*! \brief whether push constant is defined */ Value push_const_; /*! \brief map from type code to the type */ @@ -667,8 +717,21 @@ class IRBuilder { std::vector decorate_; /*! \brief Global segment: types, variables, types */ std::vector global_; - /*! \brief Function header segment */ + /*! \brief Function header segment + * + * Contains the start of function (spv::OpFunction), first label + * (spv::OpLabel), and all array allocations (spv::OpVariable). + */ std::vector func_header_; + /*! \brief Function-scope variable declarations + * + * Contains variable declarations that should be accessible + * throughout the entire kernel (e.g. threadIdx.x). This must be + * separate from func_header_, because the function-level + * spv::OpVariable declarations must come first in the first block + * of a function. + */ + std::vector function_scope_vars_; /*! \brief Function segment */ std::vector function_; }; diff --git a/src/target/target.cc b/src/target/target.cc index 396e264ede4d..546a3596297a 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include "../runtime/object_internal.h" @@ -210,7 +211,17 @@ ObjectRef TargetInternal::ParseType(const std::string& str, // Parsing integer int v; if (!(is >> v)) { - throw Error(": Cannot parse into type \"Integer\" from string: " + str); + std::string lower(str.size(), '\x0'); + std::transform(str.begin(), str.end(), lower.begin(), + [](unsigned char c) { return std::tolower(c); }); + // Bool is a subclass of IntImm, so allow textual boolean values. 
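The ParseType change continued just below is easiest to see from the user side: boolean-typed target options are stored as Integer (Bool is a subclass of IntImm), so they can now be written textually as well as numerically. A usage sketch, with link-params taken only as an example of a boolean-typed attribute:

    #include <tvm/target/target.h>

    void TargetBoolOptionExample() {
      // Both spellings now parse to the same Integer value.
      tvm::Target numeric("llvm -link-params=1");
      tvm::Target textual("llvm -link-params=True");  // "true"/"false" matched case-insensitively
    }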
+ if (lower == "true") { + v = 1; + } else if (lower == "false") { + v = 0; + } else { + throw Error(": Cannot parse into type \"Integer\" from string: " + str); + } } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc deleted file mode 100644 index 951bd6c18706..000000000000 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ /dev/null @@ -1,1124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file schedule_postproc_rewrite_for_tensor_core.cc - * - * \brief Rewrite the Stmt generated by ScheduleOps - * to accomondate tensorcore. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace te { - -using namespace te; -using runtime::StorageRank; -using runtime::StorageScope; -using runtime::ThreadScope; - -struct Tile { - int m{-1}; - int n{-1}; - int k{-1}; -}; - -std::string simplify_name(std::string input) { - auto pos = input.find("."); - if (pos != std::string::npos) { - return input.substr(0, pos); - } else { - return input; - } -} - -PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { - auto cast = input.as(); - if (cast == nullptr) { - return input; - } else if (cast->dtype == target_type) { - return cast->value; - } - return PrimExpr(); -} - -// MMAMatcher matches C = Cast(A)*Cast(B)+C, -// where A & B are fp16/int8 local buffers, -// and C is fp32/int32 local buffer. 
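The comment above summarizes the pattern the deleted MMAMatcher recognised in the lowered statement. For orientation, a TE definition of the kind that, once its buffers are scheduled into "local" scope, lowers to such a statement; this is only a sketch with arbitrary shapes and names:

    #include <tvm/te/operation.h>
    #include <tvm/tir/op.h>

    // C[i, j] accumulates cast(A[i, k]) * cast(B[k, j]) with fp16 inputs and an fp32 accumulator.
    tvm::te::Tensor MakeTensorCoreCandidate() {
      using tvm::DataType;
      namespace te = tvm::te;
      te::Tensor A = te::placeholder({16, 16}, DataType::Float(16), "A");
      te::Tensor B = te::placeholder({16, 16}, DataType::Float(16), "B");
      tvm::tir::IterVar k = te::reduce_axis(tvm::Range(0, 16), "k");
      return te::compute(
          {16, 16},
          [&](tvm::tir::Var i, tvm::tir::Var j) {
            return tvm::sum(tvm::cast(DataType::Float(32), A(i, k)) *
                                tvm::cast(DataType::Float(32), B(k, j)),
                            {k});
          },
          "C");
    }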
-class MMAMatcher : public StmtVisitor { - public: - explicit MMAMatcher(Map extern_buffer) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::pragma_tensor_core) { - tensor_core_on_ = true; - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(Downcast(op->producer)); - if (it == buf_map_.end()) { - return; - } - const BufferInfo& bi = it->second; - if (bi.released) { - return; - } - if (tensor_core_on_ && mma_sync_match_(op, bi)) { - matched_ = true; - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - if (!buf_map_.at(key).external) { - return; - } - this->VisitStmt(op->body); - } else { - BufferInfo bi; - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - inline bool Matched() const { return matched_; } - - friend class ScheduleAnalyser; - friend class BufferAnalyser; - - private: - struct BufferInfo { - std::string name; - DataType dtype; - bool external{false}; - bool released{false}; - bool same_as(const BufferInfo& bi) { - if (this->dtype != bi.dtype) return false; - if (this->name != bi.name) return false; - if (this->external != bi.external) return false; - if (this->released != bi.released) return false; - return true; - } - }; - - // Check whether the storage scope is local - bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { - auto tensor = Downcast(op->producer); - auto it = storage_scope_.find(tensor.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(tensor); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; - } - - // Do the pattern matching - bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { - auto* add = op->value.as(); - if (add == nullptr) { - return false; - } - - auto* load_c = add->a.as(); - BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || - !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { - return false; - } - - auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); - if (mul == nullptr) { - return false; - } - - auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); - BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) || - !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); - BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) || - !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == 
DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { - return false; - } - - frag_reg_.insert(buffer_c.name); - frag_reg_.insert(buffer_a.name); - frag_reg_.insert(buffer_b.name); - buf_name_.insert(std::make_pair(load_a, buffer_a.name)); - buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); - - return true; - } - - std::unordered_map buf_map_; - std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; - std::unordered_set frag_reg_; - bool matched_{false}; - bool tensor_core_on_{false}; -}; - -// BodyVisitor visits the body stmt of original ComputeOp -// to get the access indices of input matrices, -// if it is recognized as matrix multiply. -class BodyVisitor : public StmtExprVisitor { - public: - BodyVisitor() {} - - void VisitExpr_(const ReduceNode* op) final { - auto* comm_add = op->combiner->result[0].as(); - if (comm_add == nullptr || op->combiner->result.size() > 1) { - return; - } - for (PrimExpr source : op->source) { - auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); - auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); - if (mul_0 == nullptr && mul_1 == nullptr) { - continue; - } - - tensorcore_candidate_ = true; - StmtExprVisitor::VisitExpr(source); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); - } - - friend class ScheduleAnalyser; - - private: - std::unordered_map> args_; - bool tensorcore_candidate_{false}; -}; - -// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major -class ScheduleAnalyser { - public: - explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} - - bool MatrixIdentify(Schedule schedule) { - // TODO(minmin): handle the case where MatMul is not the output stage - for (Operation output : schedule->outputs) { - const ComputeOpNode* compute = output.as(); - if (compute == nullptr) { - // Not a ComputeOp - continue; - } - auto axis = compute->axis; - auto reduce_axis = compute->reduce_axis; - if (axis.size() < 2 || reduce_axis.size() != 1) { - continue; - } - const VarNode* axis_var[2]; - const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size() - 2]->var.as(); - axis_var[1] = axis[axis.size() - 1]->var.as(); - reduce_axis_var = reduce_axis[0]->var.as(); - - BodyVisitor body_visitor; - for (PrimExpr expr : compute->body) { - body_visitor(expr); - } - if (!body_visitor.tensorcore_candidate_) { - continue; - } - for (auto iter : body_visitor.args_) { - auto name = iter.first; - auto args = iter.second; - if (args.size() < 2) { - continue; - } - const VarNode* var0 = args[args.size() - 2].as(); - const VarNode* var1 = args[args.size() - 1].as(); - if (var0 == nullptr || var1 == nullptr) { - continue; - } - std::string matrix_abc, major; - if (var0 == reduce_axis_var && var1 == axis_var[1]) { - matrix_abc = "matrix_a"; - major = "col_major"; - } else if (var0 == reduce_axis_var && var1 == axis_var[0]) { - matrix_abc = "matrix_b"; - major = "row_major"; - } else if (var0 == axis_var[1] && var1 == reduce_axis_var) { - matrix_abc = "matrix_a"; - major = "row_major"; - } else if (var0 == axis_var[0] && var1 == reduce_axis_var) { - matrix_abc = "matrix_b"; - major = 
"col_major"; - } - matrix_abc_.insert(std::make_pair(name, matrix_abc)); - matrix_major_.insert(std::make_pair(name, major)); - } - matrix_abc_.insert(std::make_pair(compute->name, "accumulator")); - matrix_major_.insert(std::make_pair(compute->name, "col_major")); - } - - for (auto& mma_sync : mma_sync_) { - auto& operands = mma_sync.second; - auto* load_a = operands[0].as(); - auto* load_b = operands[1].as(); - auto input0 = simplify_name(buf_name_.find(load_a)->second); - auto input1 = simplify_name(buf_name_.find(load_b)->second); - auto it0 = matrix_abc_.find(input0); - auto it1 = matrix_abc_.find(input1); - - if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) { - return false; - } - if (it0->second == "matrix_a" && it1->second == "matrix_b") { - return true; - } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { - mma_sync.second = Array{operands[1], operands[0], operands[2]}; - } else { - return false; - } - } - return true; - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map buf_name_; -}; - -// IndexVisitor visits access index of fragment -// to record variable for loop scaling -class IndexVisitor : public StmtExprVisitor { - public: - IndexVisitor() {} - - void VisitExpr_(const VarNode* op) final { - loop_scaling_.insert(std::make_pair(op, scaling_factor_)); - } - - friend class BufferAnalyser; - friend class TensorCoreIRMutator; - - private: - std::unordered_map loop_scaling_; - unsigned scaling_factor_{0}; -}; - -// BufferAnalyser gets buffer info, -// e.g. thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public StmtExprVisitor { - public: - explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - frag_reg_(mma_matcher.frag_reg_) { - for (auto kv : extern_buffer) { - BufferInfo bi; - bi.name = kv.second->name; - bi.dtype = kv.second->dtype; - bi.strides = kv.second->strides; - bi.shape = kv.second->shape; - bi.external = true; - buf_map_[kv.first] = bi; - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - if (const IntImmNode* value = op->value.as()) { - thread_extent_.insert( - std::make_pair(op->node.as()->var->name_hint, value->value)); - } - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == tir::attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - this->VisitStmt(op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align) { - te::Tensor tensor = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[tensor]; - size_t dim = tuple->args[0].as()->value; - if (dim >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - this->VisitStmt(op->body); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ProducerStoreNode* op) final { - StmtExprVisitor::VisitStmt_(op); - auto key = Downcast(op->producer); - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << 
"Read a buffer that is already out of scope"; - - if (matrix_abc_.count(key->GetNameHint())) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(key->GetNameHint(), strides)); - - if (frag_reg_.count(bi.name)) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_load_.insert(std::make_pair(op, dst)); - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - std::vector tile_size; - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - tile_size.push_back(shape->value); - index_visitor.scaling_factor_ = shape->value; - } else { - invalid_ = true; - return; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - - std::string input_name = simplify_name(bi.name); - auto it = matrix_abc_.find(input_name); - auto it2 = matrix_major_.find(input_name); - bool ret = true; - if (it != matrix_abc_.end() && it2 != matrix_major_.end()) { - if (it->second == "matrix_a" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.m, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - ret &= assign_or_check_(&thread_tile_.n, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); - } - if (it->second == "accumulator") { - ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); - ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); - } - if (!ret) { - invalid_ = true; - return; - } - } - } - - const ProducerLoadNode* value = op->value.as(); - // TODO(tvm-team): string matching is dangerous, consider other means. 
- if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { - PrimExpr dst = ProducerLoad(op->producer, op->indices); - frag_store_.insert(std::make_pair(op, dst)); - } - } - - void VisitExpr_(const ProducerLoadNode* op) final { - StmtExprVisitor::VisitExpr_(op); - - auto tensor = Downcast(op->producer); - auto it = buf_map_.find(tensor); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); - const BufferInfo& bi = it->second; - ICHECK(!bi.released) << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(tensor->op->name)) { - if (bi.shape.size() < 2) { - invalid_ = true; - return; - } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } - } - - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul(stride, bi.shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - } - strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); - - if (!frag_reg_.count(bi.name)) { - return; - } - - auto rel_index = bi.RelIndex(op->indices); - if (op->indices.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); - } - } - - void VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external); - this->VisitStmt(op->body); - } else { - // create a buffer entry - BufferInfo bi; - - bi.bounds = op->bounds; - Array shape; - for (auto r : bi.bounds) { - shape.push_back(r->extent); - } - - Array strides; - if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; - const std::vector& avec = dim_align_[key]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = analyzer_.Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - strides = Array(rstrides.rbegin(), rstrides.rend()); - } - - bi.name = key->GetNameHint(); - bi.dtype = key->dtype; - bi.strides = strides; - bi.shape = shape; - - buf_map_[key] = bi; - this->VisitStmt(op->body); - buf_map_[key].released = true; - } - } - - // Derive warp tile from thread tile, - // and check whether it is qualified for TensorCore. 
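A worked example of the derivation described in the comment above, with made-up extents (a CUDA warp has 32 threads, and threadIdx.y must additionally be a positive multiple of 32 / threadIdx.x for the check below to pass):

    // Illustrative numbers only; Tile is the {m, n, k} struct defined near the top of this file.
    inline Tile DeriveWarpTileExample() {
      int warp_threads_x = 16;                       // extent of threadIdx.x
      Tile thread_tile{/*m=*/1, /*n=*/8, /*k=*/16};  // per-thread tile observed by BufferAnalyser
      int warp_threads_y = 32 / warp_threads_x;      // = 2
      Tile warp_tile;
      warp_tile.m = warp_threads_x * thread_tile.m;  // = 16
      warp_tile.n = warp_threads_y * thread_tile.n;  // = 16
      warp_tile.k = thread_tile.k;                   // = 16 -> 16x16x16 is a supported warp tile
      return warp_tile;
    }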
- bool QualifiedForTensorCore() { - if (invalid_) { - return false; - } - auto itx = thread_extent_.find("threadIdx.x"); - if (itx == thread_extent_.end()) { - return false; - } - int warp_threads_x = itx->second; - warp_tile_.m = warp_threads_x * thread_tile_.m; - warp_threads_y_ = 32 / warp_threads_x; - auto ity = thread_extent_.find("threadIdx.y"); - if (ity == thread_extent_.end()) { - return false; - } - if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) { - return false; - } - warp_tile_.n = warp_threads_y_ * thread_tile_.n; - warp_tile_.k = thread_tile_.k; - return supported_warp_tile_(); - } - - friend class TensorCoreIRMutator; - - private: - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - - struct BufferInfo { - std::string name; - DataType dtype; - Array strides; - Array shape; - Region bounds; - bool external{false}; - bool released{false}; - inline Array RelIndex(Array args) const { - if (bounds.size() != 0) { - Array index; - ICHECK_EQ(bounds.size(), args.size()); - for (size_t i = 0; i < bounds.size(); ++i) { - index.push_back(args[i] - bounds[i]->min); - } - return index; - } else { - return args; - } - } - }; - - bool assign_or_check_(int* dst, int src) { - if (*dst <= 0) { - *dst = src; - return true; - } - if (*dst == src) { - return true; - } - return false; - } - - bool supported_warp_tile_() { - if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { - return true; - } - if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { - return true; - } - - return false; - } - - std::unordered_map buf_map_; - std::unordered_map> dim_align_; - std::unordered_map storage_scope_; - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_set frag_reg_; - std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map thread_extent_; - IndexVisitor index_visitor; - Tile warp_tile_; - Tile thread_tile_; - arith::Analyzer analyzer_; - int warp_threads_y_{-1}; - bool invalid_{false}; -}; - -// ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public StmtExprMutator { - public: - explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr) { - if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImm(DataType::Int(32), 0); - return zero; - } - if (op->name_hint == "threadIdx.y") { - PrimExpr div = Div(expr, warp_y_); - PrimExpr mul = Mul(div, warp_y_); - return mul; - } - } - return expr; - } - - private: - PrimExpr warp_y_; -}; - -// TensorCoreIRMutator mutates the AST for TensorCore CodeGen -// based on tensor core intrinsics -class TensorCoreIRMutator : public StmtExprMutator { - public: - explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, - const BufferAnalyser& buffer_analyser) - : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - 
frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} - - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - auto key = Downcast(op->producer); - bounds_[key] = op->bounds; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - if (!frag_reg_.count(key->GetNameHint())) { - return stmt; - } - - auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); - - Region new_bounds; - for (size_t i = 0; i < op->bounds.size() - 2; ++i) { - new_bounds.push_back(op->bounds[i]); - } - ICHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back( - Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return ProducerRealize(op->producer, new_bounds, op->condition, op->body); - } - return stmt; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (op->attr_key == tir::attr::realize_scope) { - auto node = op->node.as(); - if (node != nullptr) { - if (!frag_reg_.count(node->name)) { - return stmt; - } - - auto it = matrix_abc_.find(simplify_name(node->name)); - ICHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImm("wmma." + it->second); - Stmt body = this->VisitStmt(op->body); - return AttrStmt(op->node, op->attr_key, matrix_abc, body); - } - } - return stmt; - } - - Stmt VisitStmt_(const ProducerStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - auto it = mma_sync_.find(op); - if (it != mma_sync_.end()) { - const auto& operands = it->second; - PrimExpr a = operands[0]; - auto ca = a.as(); - PrimExpr b = operands[1]; - auto cb = b.as(); - PrimExpr c = operands[2]; - auto cc = c.as(); - - ObjectPtr buffer_node_a = make_object(); - ObjectPtr buffer_node_b = make_object(); - ObjectPtr buffer_node_c = make_object(); - - auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_bmma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } else { - return Evaluate( - Call(DataType::Handle(), builtin::tvm_mma_sync(), - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset})); - } - }; - - auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); - }; - - auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); - }; - - return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); - } - - auto it2 = frag_load_.find(op); - if (it2 != frag_load_.end()) { - PrimExpr dst = it2->second; - if (op->value.as() != nullptr || op->value.as() != nullptr) { - auto pload = dst.as(); - - auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), 
builtin::tvm_fill_fragment(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); - } - - const ProducerLoadNode* value = op->value.as(); - ICHECK(value != nullptr) << "Can only load fragment from a buffer"; - - auto it = strides_.find(value->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - PrimExpr mutated_value = thread_idx_mutator(op->value); - // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}); - - auto pload = dst.as(); - PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); - ICHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << pload->producer->GetNameHint(); - if (iter2->second == "col_major") { - matrix_major = StringImm("col_major"); - } else if (iter2->second == "row_major") { - matrix_major = StringImm("row_major"); - } else { - LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); - } - - auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); - } - - auto it3 = frag_store_.find(op); - if (it3 != frag_store_.end()) { - auto it = strides_.find(op->producer->GetNameHint()); - ICHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); - auto strides = it->second; - ICHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size() - 2]; - - PrimExpr dst = it3->second; - // thread index unification inside a warp - PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); - ThreadIdxMutator thread_idx_mutator(warp_y); - dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}); - - auto pload = op->value.as(); - - auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")})); - }; - - ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); - } - - return stmt; - } - - Stmt VisitStmt_(const ForNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - if (op != nullptr) { - auto it = loop_scaling_.find(op->loop_var.get()); - if (it != loop_scaling_.end()) { - int scale_factor = it->second; - int scaled_extent_value = 1; - if (const IntImmNode* ori_extent = op->extent.as()) { - int ori_extent_value = ori_extent->value; - scaled_extent_value = ori_extent_value / scale_factor; - } - PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = 
For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, - op->annotations); - } - } - return stmt; - } - - private: - Array get_tile_size_(const std::string& name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - ICHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; - } - - Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, - const ObjectPtr& buffer_node, - const std::function& call_back) { - auto tensor = Downcast(pload->producer); - auto it = bounds_.find(tensor); - ICHECK(it != bounds_.end()); - Array min_bound; - for (auto i : it->second) { - min_bound.push_back(i->min); - } - - ICHECK_GE(it->second.size(), 2); - Array shape; - for (size_t i = 0; i < it->second.size() - 2; ++i) { - shape.push_back(it->second[i]->extent); - } - auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); - shape.push_back(tile_size[0]); - shape.push_back(tile_size[1]); - - Array strides; - for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul(stride, shape[j]); - } - strides.push_back(stride); - } - strides.push_back(make_const(DataType::Int(32), 1)); - - PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - ICHECK_EQ(pload->indices.size(), min_bound.size()); - for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); - } - - auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); - ICHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; - buffer_node->data = Var(tensor->op->name, DataType::Handle()); - buffer_node->name = tensor->op->name; - buffer_node->scope = "wmma." 
+ it2->second; - buffer_node->dtype = tensor->dtype; - buffer_node->strides = strides; - buffer_node->shape = shape; - buffer_node->data_alignment = 1; - buffer_node->elem_offset = analyzer_.Simplify(elem_offset); - buffer_node->offset_factor = 1; - Buffer buffer(buffer_node); - - Array args; - for (size_t i = 0; i < pload->indices.size(); ++i) { - args.push_back(pload->indices[i]); - args.push_back(shape[i]); - } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args); - Array node = {buffer, tensor}; - return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); - } - - std::unordered_map matrix_abc_; - std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map> strides_; - std::unordered_set frag_reg_; - std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; - arith::Analyzer analyzer_; - Tile warp_tile_; - int warp_threads_y_{-1}; -}; - -Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, - Map extern_buffer) { - // Check if current lower target is CUDA - auto target = tvm::Target::Current(true); - if (target.defined() && target->kind->name != "cuda") { - return stmt; - } - - // Check if current runtime support GPU CUDA - Device dev{kDLCUDA, 0}; - auto api = tvm::runtime::DeviceAPI::Get(dev, true); - if (api == nullptr) { - return stmt; - } - - MMAMatcher mma_matcher(extern_buffer); - mma_matcher(stmt); - if (!mma_matcher.Matched()) { - return stmt; - } - - ScheduleAnalyser schedule_analyser(mma_matcher); - if (!schedule_analyser.MatrixIdentify(schedule)) { - return stmt; - } - - BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser(stmt); - if (!buffer_analyser.QualifiedForTensorCore()) { - return stmt; - } - - return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); -} - -TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") - .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); - }); - -} // namespace te -} // namespace tvm diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 8cc5c4bc0a3a..204a824f9248 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = lower(s, args, "func", binds); + auto lowered = LowerSchedule(s, args, "func", binds); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -116,8 +116,8 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = lower(s1, args1, "elemwise_add", binds); - auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/microtvm_runtime_standalone_test.cc similarity index 86% rename from tests/cpp/utvm_runtime_standalone_test.cc rename to tests/cpp/microtvm_runtime_standalone_test.cc index e674c3b74144..0da88cfe64e5 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/microtvm_runtime_standalone_test.cc @@ 
-38,7 +38,7 @@ #include #include #include -#include +#include #include #include #include @@ -107,23 +107,23 @@ TEST(MicroStandaloneRuntime, BuildModule) { const auto ret = system(ss.c_str()); ASSERT_EQ(ret, 0); // Now, execute the minimal runtime. - auto* dsoModule = UTVMRuntimeDSOModuleCreate(so_fname.c_str(), so_fname.size()); + auto* dsoModule = MicroTVMRuntimeDSOModuleCreate(so_fname.c_str(), so_fname.size()); ASSERT_NE(dsoModule, nullptr); - auto* handle = UTVMRuntimeCreate(json.c_str(), json.size(), dsoModule); + auto* handle = MicroTVMRuntimeCreate(json.c_str(), json.size(), dsoModule); ASSERT_NE(handle, nullptr); - UTVMRuntimeSetInput(handle, 0, const_cast(A.operator->())); - UTVMRuntimeSetInput(handle, 1, const_cast(B.operator->())); - UTVMRuntimeSetInput(handle, 2, const_cast(C.operator->())); - UTVMRuntimeRun(handle); + MicroTVMRuntimeSetInput(handle, 0, const_cast(A.operator->())); + MicroTVMRuntimeSetInput(handle, 1, const_cast(B.operator->())); + MicroTVMRuntimeSetInput(handle, 2, const_cast(C.operator->())); + MicroTVMRuntimeRun(handle); auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - UTVMRuntimeGetOutput(handle, 0, const_cast(Y.operator->())); + MicroTVMRuntimeGetOutput(handle, 0, const_cast(Y.operator->())); auto* pY = (float*)Y->data; for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); } - UTVMRuntimeDestroy(handle); - UTVMRuntimeDSOModuleDestroy(dsoModule); + MicroTVMRuntimeDestroy(handle); + MicroTVMRuntimeDSOModuleDestroy(dsoModule); } #endif diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 96bcdfe5d86d..cf9447bc2b10 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -334,8 +334,8 @@ def check_result( tvm.testing.assert_allclose(out.numpy(), results[idx], rtol=TOL, atol=TOL) -def test_byoc_utvm(platform, west_cmd, skip_build, tvm_debug): - """This is a simple test case to check BYOC capabilities of uTVM""" +def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): + """This is a simple test case to check BYOC capabilities of microTVM""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} x = relay.var("x", shape=(10, 10)) diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index dc277c245078..afdbdc590de0 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -118,6 +118,8 @@ def _create_header_file(tensor_name, npy_data, output_path): header_file.write(f"uint8_t {tensor_name}[] = ") elif npy_data.dtype == "float32": header_file.write(f"float {tensor_name}[] = ") + else: + raise ValueError("Data type not expected.") header_file.write("{") for i in np.ndindex(npy_data.shape): @@ -211,5 +213,46 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): assert result == 8 +def test_qemu_make_fail(platform, west_cmd, skip_build, tvm_debug): + """Testing QEMU make fail.""" + model, zephyr_board = PLATFORMS[platform] + build_config = {"skip_build": skip_build, "debug": tvm_debug} + shape = (10,) + dtype = "float32" + + this_dir = pathlib.Path(__file__).parent + tvm_source_dir = this_dir / ".." / ".." / ".." + runtime_path = tvm_source_dir / "apps" / "microtvm" / "zephyr" / "aot_demo" + + # Construct Relay program. 
+ x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) + func = relay.Function([x], z) + + target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lowered = relay.build(func, target) + + # Generate input/output header files + model_files_path = os.path.join(runtime_path, "include") + _create_header_file((f"input_data"), np.zeros(shape=shape, dtype=dtype), model_files_path) + _create_header_file("output_data", np.zeros(shape=shape, dtype=dtype), model_files_path) + + session_kw = _build_session_kw( + model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config + ) + + file_path = os.path.join(session_kw["binary"].base_dir, "zephyr/CMakeFiles/run.dir/build.make") + assert os.path.isfile(file_path), f"[{file_path}] does not exist." + + # Remove a file to create make failure. + os.remove(file_path) + transport = session_kw["flasher"].flash(session_kw["binary"]) + with pytest.raises(RuntimeError) as excinfo: + transport.open() + assert "QEMU setup failed" in str(excinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index bb44b79078dd..8fcafe489cb9 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -56,7 +56,10 @@ def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, def _get_tflite_model(tflite_model_path, inputs_dict): """Convert TFlite graph to relay.""" - import tflite.Model + try: + import tflite.Model + except ImportError: + pytest.skip("Missing Tflite support") with open(tflite_model_path, "rb") as f: tflite_model_buffer = f.read() @@ -92,7 +95,10 @@ def test_vgg16(): device = Device() def get_model(): - from keras.applications import VGG16 + try: + from keras.applications import VGG16 + except ImportError: + pytest.skip("Missing Keras Package") vgg16 = VGG16(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000) inputs = {vgg16.input_names[0]: ((1, 224, 224, 3), "float32")} @@ -113,7 +119,10 @@ def test_mobilenet(): device = Device() def get_model(): - from keras.applications import MobileNet + try: + from keras.applications import MobileNet + except ImportError: + pytest.skip("Missing keras module") mobilenet = MobileNet( include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000 @@ -133,7 +142,10 @@ def test_quantized_mobilenet(): if skip_runtime_test(): return - import tvm.relay.testing.tf as tf_testing + try: + import tvm.relay.testing.tf as tf_testing + except ImportError: + pytest.skip("Missing Tflite support") device = Device() @@ -158,7 +170,10 @@ def test_squeezenet(): if skip_runtime_test(): return - import tvm.relay.testing.tf as tf_testing + try: + import tvm.relay.testing.tf as tf_testing + except ImportError: + pytest.skip("Missing TF Support") device = Device() diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 137484330db8..9deaa758639e 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -169,34 +169,37 @@ def test_pooling(): fp32_dtype = 
("float32", -127, 128, 0.001, 0.001) uint8_dtype = ("uint8", 0, 255, 1, 0) - + # fmt: off trials = [ - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (27, 27, 512)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (27, 27, 512), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16), (1, 0),], # 20.05: "exclude_padding equal false is not supported for AVG Pooling with padding on quantized types" # ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)], - ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16), (0, 1),], ] - + # fmt: on for ( typef, (dtype, low, high, atol, rtol), size, stride, + dilation, pad, ceil_mode, count_include_pad, input_shape, + (tvm_ops, acl_partitions), ) in trials: shape = (1, *input_shape) outputs = [] @@ -205,7 +208,16 @@ def test_pooling(): } func = _get_pooling_model( - shape, dtype, typef, size, stride, pad, ceil_mode, count_include_pad, iter(inputs) + shape, + dtype, + typef, + size, + stride, + dilation, + pad, + ceil_mode, + count_include_pad, + iter(inputs), ) config = { @@ -215,15 +227,25 @@ def test_pooling(): "pooling type": typef, "dtype": dtype, "padding": pad, + "dilation": dilation, "ceil_mode": ceil_mode, "count_include_pad": count_include_pad, "inputs": inputs, } verify_saturation = True if dtype == "uint8" 
else False - for acl in [False, True]: outputs.append( - build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0] + build_and_run( + func, + inputs, + 1, + None, + device, + enable_acl=acl, + tvm_ops=tvm_ops, + acl_partitions=acl_partitions, + config=config, + )[0] ) verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) @@ -283,25 +305,25 @@ def test_codegen_pooling(): fp32_dtype = ("float32", -127, 128) uint8_dtype = ("uint8", 0, 255) - + # fmt: off trials = [ - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16)], - ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (16, 16, 16)], - ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (15, 15, 16)], - ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16), (0, 1),], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16), (1, 0),], + ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (16, 16, 16), (0, 1),], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (15, 15, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16), (0, 1),], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16), (0, 1),], ] - + # fmt: on for ( typef, (dtype, low, high), @@ -312,6 +334,7 @@ def test_codegen_pooling(): ceil_mode, 
count_include_pad, input_shape, + (tvm_ops, acl_partitions), ) in trials: shape = (1, *input_shape) inputs = {"a"} @@ -319,7 +342,7 @@ def test_codegen_pooling(): func = _get_pooling_model(*args, iter(inputs)) exp_codegen = _get_expected_pooling_codegen(*args) - verify_codegen(func, exp_codegen, 1) + verify_codegen(func, exp_codegen, acl_partitions, tvm_ops) def test_codegen_global_pooling(): diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index b4fc2b283369..2b99879d8227 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -158,13 +158,14 @@ def test_quantized_matmul_add(): def verify_batch_matmul( - batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32" + batch_a, batch_b, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32" ): - ashape = (batch, l, n) if transa else (batch, n, l) - bshape = (batch, m, l) if transb else (batch, l, m) + batch = max(batch_a, batch_b) + ashape = (batch_a, l, n) if transa else (batch_a, n, l) + bshape = (batch_b, m, l) if transb else (batch_b, l, m) A = te.placeholder(ashape, name="A", dtype=dtype) B = te.placeholder(bshape, name="B", dtype=dtype) - C = cblas.batch_matmul(A, B, transa, transb) + C = lib.batch_matmul(A, B, transa, transb) D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D") s = te.create_schedule(D.op) @@ -207,24 +208,32 @@ def verify(target="llvm"): def test_batch_matmul(): - verify_batch_matmul(16, 235, 128, 1024, cblas) - verify_batch_matmul(16, 235, 128, 1024, cblas, True, False) - verify_batch_matmul(16, 235, 128, 1024, cblas, False, True) - verify_batch_matmul(16, 235, 128, 1024, cblas, True, True) - verify_batch_matmul(16, 235, 128, 1024, mkl) - verify_batch_matmul(16, 235, 128, 1024, mkl, True, False) - verify_batch_matmul(16, 235, 128, 1024, mkl, False, True) - verify_batch_matmul(16, 235, 128, 1024, mkl, True, True) - verify_batch_matmul(1, 1, 16, 3, cblas) - verify_batch_matmul(1, 1, 16, 3, cblas, True, False) - verify_batch_matmul(1, 1, 16, 3, cblas, False, False) - verify_batch_matmul(1, 1, 16, 3, cblas, True, True) - verify_batch_matmul(1, 1, 16, 3, cblas, iterative=True) - verify_batch_matmul(1, 1, 16, 3, mkl) - verify_batch_matmul(1, 1, 16, 3, mkl, True, False) - verify_batch_matmul(1, 1, 16, 3, mkl, False, False) - verify_batch_matmul(1, 1, 16, 3, mkl, True, True) - verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, False) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, False, True) + verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, True) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, False) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, False, True) + verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, True) + verify_batch_matmul(16, 1, 235, 128, 1024, cblas) + verify_batch_matmul(1, 16, 235, 128, 1024, cblas) + verify_batch_matmul(16, 1, 235, 128, 1024, cblas, iterative=True) + verify_batch_matmul(1, 16, 235, 128, 1024, cblas, iterative=True) + verify_batch_matmul(16, 1, 235, 128, 1024, mkl) + verify_batch_matmul(1, 16, 235, 128, 1024, mkl) + verify_batch_matmul(16, 1, 235, 128, 1024, mkl, iterative=True) + verify_batch_matmul(1, 16, 235, 128, 1024, mkl, iterative=True) + verify_batch_matmul(1, 1, 1, 16, 3, cblas) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, False) + verify_batch_matmul(1, 1, 
1, 16, 3, cblas, False, False) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, True) + verify_batch_matmul(1, 1, 1, 16, 3, cblas, iterative=True) + verify_batch_matmul(1, 1, 1, 16, 3, mkl) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, False) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, False, False) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, True) + verify_batch_matmul(1, 1, 1, 16, 3, mkl, iterative=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index a0f51ca7c9fc..648100a569d7 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -112,33 +112,23 @@ def verify(target="cuda"): verify() -def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): - j = 16 - n = 1024 - l = 128 - m = 236 - A = te.placeholder((j, n, l), name="A", dtype=in_dtype) - B = te.placeholder((j, l, m), name="B", dtype=in_dtype) +def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): + A = te.placeholder(Ashape, name="A", dtype=in_dtype) + B = te.placeholder(Bshape, name="B", dtype=in_dtype) C = cublas.batch_matmul(A, B, dtype=out_dtype) s = te.create_schedule(C.op) - def verify(target="cuda"): - if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): - print("skip because extern function is not available") - return - dev = tvm.cuda(0) - f = tvm.build(s, [A, B, C], target) - a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), dev) - f(a, b, c) - tvm.testing.assert_allclose( - c.numpy(), - np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), - rtol=rtol, - ) - - verify() + dev = tvm.cuda(0) + f = tvm.build(s, [A, B, C], "cuda") + a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev) + f(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), + np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), + rtol=rtol, + ) @tvm.testing.requires_cuda @@ -156,9 +146,20 @@ def test_matmul_add_igemm(): @tvm.testing.requires_cuda def test_batch_matmul(): - verify_batch_matmul("float", "float") - verify_batch_matmul("float16", "float") - verify_batch_matmul("float16", "float16", rtol=1e-2) + if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + print("skip because extern function is not available") + return + + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul( + (16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) + verify_batch_matmul( + (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) if __name__ == "__main__": diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 8a929f550a4f..7651bdea36a6 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -14,6 +14,11 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. + +import sys + +import pytest + import tvm from tvm import te from tvm.contrib import cudnn @@ -23,6 +28,12 @@ import tvm.testing +requires_cudnn = pytest.mark.skipif( + tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None, + reason="CuDNN is not enabled", +) + + def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 @@ -38,9 +49,6 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return if data_dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -123,10 +131,6 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return - # schedule xshape = [batch, in_channel, depth, height, width] wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w] @@ -205,11 +209,8 @@ def verify_softmax_4d(shape, dtype="float32"): @tvm.testing.requires_gpu +@requires_cudnn def test_softmax(): - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return - verify_softmax((32, 10), -1) verify_softmax((3, 4), -1) verify_softmax((1, 5), -1, "float64") @@ -217,7 +218,84 @@ def test_softmax(): verify_softmax_4d((1, 16, 256, 256), "float64") +test_kwargs_default_2d = { + "tensor_format": 0, + "pad": [1, 1], + "stride": [1, 1], + "dilation": [1, 1], + "x_shape": [16, 4, 32, 32], + "w_shape": [8, 4, 3, 3], + "groups": 1, + "conv_dtype": "float32", + "data_dtype": "float32", +} +test_kwargs_default_3d = { + "tensor_format": 0, + "pad": [1, 1, 1], + "stride": [1, 1, 1], + "dilation": [1, 1, 1], + "x_shape": [16, 4, 32, 32, 32], + "w_shape": [8, 4, 3, 3, 3], + "groups": 1, + "conv_dtype": "float32", + "data_dtype": "float32", +} +conv_output_shape_conditions = { + "2d_small": test_kwargs_default_2d, + "2d_large": { + **test_kwargs_default_2d, + "x_shape": [16, 32, 512, 1024], + "w_shape": [8, 32, 5, 5], + }, + "2d_pad": {**test_kwargs_default_2d, "pad": [2, 3]}, + "2d_stride": {**test_kwargs_default_2d, "stride": [2, 3]}, + "2d_dilation": {**test_kwargs_default_2d, "dilation": [2, 3]}, + "2d_groups": {**test_kwargs_default_2d, "groups": 4, "w_shape": [8, 1, 3, 3]}, + "2d_NHWC": { + **test_kwargs_default_2d, + "tensor_format": 1, + "x_shape": [16, 32, 32, 4], + "w_shape": [8, 3, 3, 4], + }, + "2d_NCHW_VECT_C": { + **test_kwargs_default_2d, + "tensor_format": 2, + "w_shape": [8, 16, 3, 3], + "data_dtype": "int8x4", + }, + "3d_small": test_kwargs_default_3d, + "3d_large": { + **test_kwargs_default_3d, + "x_shape": [16, 32, 64, 128, 256], + "w_shape": [8, 32, 5, 5, 5], + }, + "3d_pad": {**test_kwargs_default_3d, "pad": [2, 3, 4]}, + "3d_stride": {**test_kwargs_default_3d, "stride": [2, 3, 4]}, + "3d_dilation": {**test_kwargs_default_3d, "dilation": [2, 3, 4]}, + "3d_groups": {**test_kwargs_default_3d, "groups": 4, "w_shape": [8, 1, 3, 3, 3]}, + "3d_NCHW_VECT_C": { + **test_kwargs_default_3d, + "tensor_format": 2, + "w_shape": [8, 16, 3, 3, 3], + "data_dtype": "int8x4", + }, +} + + +@pytest.fixture( + params=[pytest.param(kwargs, id=name) for name, kwargs in 
conv_output_shape_conditions.items()] +) +def conv_output_shape_kwargs(request): + return request.param + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv_output_shape(conv_output_shape_kwargs): + shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs) + shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs) + assert shape_from_cudnn == shape_from_python + + if __name__ == "__main__": - test_conv2d() - test_conv3d() - test_softmax() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index ba03acc1c112..92e8f11a2312 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -326,5 +326,5 @@ def get_ethosn_api_version(): def get_ethosn_variant(): ethosn_variant_config = os.getenv("ETHOSN_VARIANT_CONFIG") if ethosn_variant_config is not None: - return 3 - return 0 + return "Ethos-N78_1TOPS_2PLE_RATIO" + return "Ethos-N77" diff --git a/tests/python/contrib/test_ethosn/test_conv2d.py b/tests/python/contrib/test_ethosn/test_conv2d.py index ca551603d13f..845cec593105 100644 --- a/tests/python/contrib/test_ethosn/test_conv2d.py +++ b/tests/python/contrib/test_ethosn/test_conv2d.py @@ -188,10 +188,6 @@ def test_conv2d_failure(): _scale_error_msg = ( "Overall scale (of the input * weights / output) should be in the range [0, 1)" ) - if tei.get_ethosn_api_version() == 2008: - _scale_error_msg = ( - "Overall scale (of the input * weights / output) should be in the range [0, 1}" - ) trials = [ ( diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index ce89c90d9379..f9a3549576c3 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -123,15 +123,11 @@ def test_mobilenet_v1(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"bfb5a50607edb50009c58ae9d4287e4d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"896c28b4f06341ea638ead3a593e1aed"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"47e216d8ab2bf491708ccf5620bc0d02"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"2436f523e263f66a063cef902f2f43d7"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"9298b6c51e2a82f70e91dd11dd6af412"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"407eb47346c8afea2d15e8f0d1c079f2"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -153,15 +149,11 @@ def test_inception_v3(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. 
_compile_hash = {"96116d7e6c7385de0688074a3f889983"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"551cde850c6ef960d19be4f317fb8e68"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"8c9d75659cd7bc9ff6dd6d490d28f9b2"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"cdd4d7f6453d722ea73224ff9d6a115a"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"d44eece5027ff56e5e7fcf014367378d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"1ba555b4bc60c428018a0f2de9d90532"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -182,17 +174,11 @@ def test_inception_v4(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = {"b34aec2a48c591818761ed6b42c133e5"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"30f078bd42757e8686eafa1f28d0d352"} - if tei.get_ethosn_api_version() == 2008: - if not tei.get_ethosn_variant() == 0: - pytest.skip( - "Ethos-N78 20.08 does not support inception_v4 in the default configuration." - ) - _compile_hash = {"798292bfa596ca7c32086396b494b46c"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"851665c060cf4719248919d17325ae02"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" @@ -213,15 +199,11 @@ def test_ssd_mobilenet_v1(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. 
_compile_hash = {"c312edfc9a946ed4dc7c049d472dae6e", "3183f0fa5eba8f6b9557d14eaf47842d"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"deee52e136327436411fc725624ae2ea", "6526509d3cbee014e38c79e22bb29d7f"} - if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"5999f26e140dee0d7866491997ef78c5", "24e3a690a7e95780052792d5626c85be"} - if tei.get_ethosn_variant() == 3: - _compile_hash = {"da871b3f03a93df69d704ed44584d6cd", "9f52411d301f3cba3f6e4c0f1c558e87"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"6e8c4586bdd26527c642a4f016f52284", "057c5efb094c79fbe4483b561147f1d2"} - if tei.get_ethosn_variant() == 3: + if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": _compile_hash = {"dc687e60a4b6750fe740853f22aeb2dc", "1949d86100004eca41099c8e6fa919ab"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 476fac5da1b9..cb6b82a32937 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -19,6 +19,7 @@ import pytest import tvm +from tvm.contrib.target.vitis_ai import vitis_ai_available from tvm.driver import tvmc from tvm.driver.tvmc.common import TVMCException @@ -306,3 +307,53 @@ def test_parse_quotes_and_separators_on_options(): assert len(targets_double_quote) == 1 assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] + + +def test_config_invalid_format(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + + +def test_config_missing_from_tvm(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + + +def test_config_unsupported_tvmc_config(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + + +def test_config_empty(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs([""]) + + +def test_config_valid_config_bool(): + configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + + assert len(configs) == 1 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == True + + +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_config_valid_multiple_configs(): + configs = tvmc.common.parse_configs( + [ + "relay.backend.use_auto_scheduler=false", + "tir.detect_global_barrier=10", + "relay.ext.vitis_ai.options.build_dir=mystring", + ] + ) + + assert len(configs) == 3 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == False + assert "tir.detect_global_barrier" in configs.keys() + assert configs["tir.detect_global_barrier"] == 10 + assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() + assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 362a9b623d25..a6c3d6efec56 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -14,22 +14,20 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np import operator +import random +import numpy as np +import pytest import tvm -from tvm import te +import tvm.testing +from tvm import relay, te from tvm.contrib import graph_executor -from tvm import relay -import mxnet as mx +import model_zoo +import mxnet as mx from mxnet import gluon from mxnet.gluon.model_zoo import vision -import random -import pytest -import model_zoo - -import tvm.testing def verify_mxnet_frontend_impl( @@ -1231,7 +1229,9 @@ def verify(shape, axis=1, epsilon=1e-5): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) op_res = intrp.evaluate()(x, gamma, beta) - tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5 + ) verify((2, 3, 4, 5)) verify((32, 64, 80, 64)) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index afd38d1c952d..cd1b80bf65d3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1851,6 +1851,9 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True) _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) + _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False) + _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False) + _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) @tvm.testing.requires_cuda @@ -1878,6 +1881,20 @@ def test_forward_batch_matmul_dynamic(): (2, 3, 4, 6, 5), "float32", ) + _test_batch_matmul_dynamic( + (None, None, None, 5, 6), + (6, None), + (2, 3, 4, 5, 6), + (6, 1), + "float32", + ) + _test_batch_matmul_dynamic( + (None, 5, 6), + (6, None), + (24, 5, 6), + (6, 1), + "float32", + ) ####################################################################### @@ -5560,5 +5577,23 @@ def @main(%A: Tensor[(4, 176, 8, 8), float32]) { tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True) +####################################################################### +# invert_permutation +# -------------------- + + +def test_invert_permutation(): + """test InvertPermutation""" + tf.reset_default_graph() + + input_shape = [6] + x = np.array([3, 4, 0, 2, 1, 5]).astype("int32") + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=input_shape, dtype="int32") + tf.invert_permutation(in_data) + out_name = "InvertPermutation:0" + compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 4f8de450d9f1..ccdc7160881c 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -340,7 +340,7 @@ def visit_call(self, call): @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) @pytest.mark.parametrize("target_options", [""]) -def test_byoc_utvm(use_calculated_workspaces, target_options): +def test_byoc_microtvm(use_calculated_workspaces, target_options): """This is a simple test case to check BYOC capabilities of AOT""" x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", 
shape=(10, 10)) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 57f07b3f00e5..13f5525bfee8 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -508,7 +508,7 @@ def verify_any_conv2d( kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) targets = None - if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): + if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True): targets = [("cuda -libs=cudnn", tvm.cuda(0))] check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 229b9905050c..1c721f40d129 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -478,11 +478,17 @@ def test_no_match_func_attr(): def test_match_call_attr(): + # String attr is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"}) x = relay.var("x") y = relay.var("y") assert is_conv2d.match(relay.op.nn.conv2d(x, y)) + # Array attr + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3]) + assert is_conv2d.match(out) + # non-operator call attr_dict = {"call_attr": "attr"} call_has_attr = wildcard()(wildcard()).has_attr(attr_dict) @@ -508,6 +514,11 @@ def test_no_match_call_attr(): is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"}) assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + # Array attr + is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]}) + out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1]) + assert not is_conv2d.match(out) + # non-operator calls call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"}) wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"}) @@ -1598,7 +1609,7 @@ def callback(self, pre, post, node_map): return x + w out = rewrite(TestRewrite(), expr) - assert tvm.ir.structural_equal(x + w, x + w) + assert tvm.ir.structural_equal(out, x + w + b) def test_partition_function_with_fuzzy_body(): diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 4968660b95c8..b4d02e4815fb 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -276,5 +276,14 @@ def test_span(): assert "Add1" in txt +def test_optional_info(): + c = relay.const(1) + call = relay.add(c, c) + m = tvm.IRModule.from_expr(call) + m = relay.transform.InferType()(m) + txt = astext(m) + assert txt.count("/* ty=int32 */") == 3 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 4710d50ea8e4..88590c946e88 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,6 +1797,90 @@ def expected(): _test_conv_reduce_convert_layout2() +def test_image_resize_convert_layout(): + def _test_image_resize_convert_layout_nchw_to_nhwc(): + def before(): + x = relay.var("x", shape=(1, 2, 4, 4)) + y = relay.image.resize(x, (8, 8)) + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 2, 4, 4)) + x = relay.layout_transform(x, "NCHW", "NHWC") + y = relay.image.resize(x, (8, 8), layout="NHWC") + 
y = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def _test_image_resize_convert_layout_nhwc_to_nchw(): + def before(): + x = relay.var("x", shape=(1, 4, 4, 2)) + y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 4, 4, 2)) + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + _test_image_resize_convert_layout_nchw_to_nhwc() + _test_image_resize_convert_layout_nhwc_to_nchw() + + +def test_conv_image_resize_convert_layout(): + """Check that layout transforms are propagated through image resize.""" + + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, "NHWC", "NCHW") + w = relay.layout_transform(w, "HWIO", "OIHW") + y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1828,3 +1912,5 @@ def expected(): test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() test_conv_strided_slice_axes_convert_layout() + test_image_resize_convert_layout() + test_conv_image_resize_convert_layout() diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index f45e39047238..1312b396fe4c 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -36,18 +36,18 @@ def run_opt_pass(expr, passes): @tvm.testing.uses_gpu -def test_legalize_conv2d(): - """test legalize conv2d to enable tensorcore""" +def test_legalize_conv2d_NHWC(): + """test legalize NHWC conv2d to enable tensorcore""" - def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, do_pad=True): + def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): out_channel = kernel_shape[3] out_shape = list(data_shape) out_shape[3] = out_channel db, di, do = pad_shape def before(): - x = relay.var("x", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=kernel_shape, dtype="float16") + x = relay.var("x", 
shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) y = relay.nn.conv2d( x, weight, @@ -67,12 +67,12 @@ def legalize_conv2d(attrs, inputs, types): def expected(): if not do_pad: return before() - x = relay.var("x", shape=data_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) if db or di: x_pad = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) if di or do: weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do))) else: @@ -99,19 +99,109 @@ def expected(): b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + for dtype in ["float16", "int8", "int4"]: + # conv2d pad batch + _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0), dtype) + _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0), dtype) + _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), dtype, False) + # conv2d pad in_channel + _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0), dtype) + _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0), dtype) + _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0), dtype) + _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), dtype, False) + # conv2d pad out_channel + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1), dtype) + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31), dtype) + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), dtype, False) + + +@tvm.testing.uses_gpu +def test_legalize_conv2d_HWNC(): + """test legalize HWNC conv2d to enable tensorcore""" + + def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): + out_channel = kernel_shape[2] + out_shape = list(data_shape) + out_shape[3] = out_channel + db, di, do = pad_shape + + def before(): + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) + y = relay.nn.conv2d( + x, + weight, + channels=out_channel, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="HWNC", + kernel_layout="HWOI", + ) + y = relay.Function([x, weight], y) + return y + + def legalize_conv2d(attrs, inputs, types): + with tvm.target.Target("cuda"): + return topi.nn.conv2d_legalize(attrs, inputs, types) + + def expected(): + if not do_pad: + return before() + x = relay.var("x", shape=data_shape, dtype=dtype) + if db or di: + x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (0, db), (0, di))) + else: + x_pad = x + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) + if di or do: + weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, do), (0, di))) + else: + weight_pad = weight + y_pad = relay.nn.conv2d( + x_pad, + weight=weight_pad, + channels=out_channel + do, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="HWNC", + kernel_layout="HWOI", + ) + if db or do: + y = relay.strided_slice(y_pad, begin=[0, 0, 0, 0], end=out_shape) + else: + y = y_pad + y = relay.Function([x, weight], y) + return y + + with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d): + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + # conv2d pad 
batch - _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0)) - _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0)) - _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), False) + _test_legalize_conv2d((16, 16, 7, 64), (3, 3, 64, 64), (1, 0, 0), "int8") + _test_legalize_conv2d((16, 16, 3, 64), (3, 3, 64, 64), (5, 0, 0), "int8") + _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), "int8", False) + _test_legalize_conv2d((16, 16, 7, 64), (3, 3, 64, 64), (1, 0, 0), "int4") + _test_legalize_conv2d((16, 16, 3, 64), (3, 3, 64, 64), (5, 0, 0), "int4") + _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), "int4", False) # conv2d pad in_channel - _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0)) - _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0)) - _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0)) - _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), False) + _test_legalize_conv2d((16, 16, 8, 63), (3, 3, 64, 63), (0, 1, 0), "int8") + _test_legalize_conv2d((16, 16, 8, 33), (3, 3, 64, 33), (0, 15, 0), "int8") + _test_legalize_conv2d((16, 16, 8, 13), (3, 3, 64, 13), (0, 3, 0), "int8") + _test_legalize_conv2d((16, 16, 8, 1), (3, 3, 64, 1), (0, 0, 0), "int8", False) + _test_legalize_conv2d((16, 16, 8, 63), (3, 3, 64, 63), (0, 1, 0), "int4") + _test_legalize_conv2d((16, 16, 8, 33), (3, 3, 64, 33), (0, 31, 0), "int4") + _test_legalize_conv2d((16, 16, 8, 13), (3, 3, 64, 13), (0, 19, 0), "int4") + _test_legalize_conv2d((16, 16, 8, 1), (3, 3, 64, 1), (0, 0, 0), "int4", False) # conv2d pad out_channel - _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1)) - _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31)) - _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), False) + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 63, 64), (0, 0, 1), "int8") + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 33, 64), (0, 0, 31), "int8") + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 1, 64), (0, 0, 0), "int8", False) + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 63, 64), (0, 0, 1), "int4") + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 33, 64), (0, 0, 7), "int4") + _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 1, 64), (0, 0, 0), "int4", False) @tvm.testing.uses_gpu @@ -234,6 +324,7 @@ def expected(): if __name__ == "__main__": - test_legalize_conv2d() + test_legalize_conv2d_NHWC() + test_legalize_conv2d_HWNC() test_legalize_dense() test_legalize_batch_matmul() diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py new file mode 100644 index 000000000000..caccd52d60c2 --- /dev/null +++ b/tests/python/relay/test_to_mixed_precision.py @@ -0,0 +1,446 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for testing ToMixedPrecision pass""" +from typing import Any, Dict, List + +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.testing import lstm +from tvm.relay.transform import InferType, ToMixedPrecision, mixed_precision + + +def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: + dev = tvm.device("llvm", 0) + intrp = relay.create_executor("debug", mod, device=dev, target="llvm") + result = intrp.evaluate()(**mod_params) + if isinstance(result, tvm.runtime.container.ADT): + result = [r.asnumpy() for r in result] + return result + else: + return [result.asnumpy()] + + +def verify_mixed_precision_output_close( + mod: tvm.runtime.Module, + mod_params: Dict[str, Any], + mixed_precision_dtype="float16", + rtol: float = 1e-3, + atol: float = 0, +) -> tvm.runtime.Module: + + mod = InferType()(mod) + result_fp32 = run_module(mod, mod_params) + fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod) + result_fp16 = run_module(fp16_mod, mod_params) + # Ensure the results are close + for fp32, fp16 in zip(result_fp32, result_fp16): + np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol) + + return fp16_mod + + +def test_lstm(): + """A small stress test on a single unrolled lstm unit. + + Has internal functions and let statements the pass must work on. + """ + units = 3 + iterations = 5 + mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units) + + # This is an unrolled lstm so each data should be the previous results but + # we don't care, we just want to stress test things. + for i in range(iterations): + mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform( + -10, 10, (1, units) + ).astype("float32") + + verify_mixed_precision_output_close(mod, mod_params, rtol=0.01, atol=0.01) + + +def test_lstm_float64(): + """Tests if can handle other mixed precision types. + + As a toy example show can convert graph to float64 and have it run. + + It doesn't really make sense to do it, this just shows we can change + the target mixed_precision_dtype. + """ + units = 3 + iterations = 5 + mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units) + + # This is an unrolled lstm so each data should be the previous results but + # we don't care, we just want to stress test things. + for i in range(iterations): + mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform( + -10, 10, (1, units) + ).astype("float32") + + verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype="float64", rtol=0.01, atol=0.01 + ) + + +def test_convert_single_conv(): + """Conv is a green listed operation meaning it will always use fp16 workload. + + By default it accumulates to fp32 and outputs fp16. 
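+
+    A minimal sketch of the rewrite this test checks for (mirroring the
+    ``expected_mod`` built below; variable names are illustrative only)::
+
+        fp16_out = relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float16",
+        )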
+ """ + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + expected_mod = tvm.IRModule.from_expr( + relay.cast( + relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float16", + ) + ) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_convert_single_conv_fp64(): + """As above but checks choosing a mixed_precision_type other than FP16 works""" + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype="float64", atol=0.01, rtol=1e-3 + ) + + # Note we still accumulate to FP32 by default, a user would need to overwrite default + # behavior to make this make more sense. + expected_mod = tvm.IRModule.from_expr( + relay.cast( + relay.nn.conv2d( + relay.cast(data, "float64"), + relay.cast(weight, "float64"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float64", + ) + ) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_convert_conv_bn(): + """Conv is green and batch norm is gray. 
As Conv should output fp16 batch_norm should be green.""" + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + + bn_shape = [5] + gamma = relay.var("gamma", shape=bn_shape) + beta = relay.var("beta", shape=bn_shape) + moving_mean = relay.var("moving_mean", shape=bn_shape) + moving_var = relay.var("moving_var", shape=bn_shape) + bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var) + mod = tvm.IRModule.from_expr(bn[0]) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + # Creating expected module + data = relay.cast(relay.var("data", shape=data_shape), "float16") + weight = relay.cast(relay.var("weight", shape=weight_shape), "float16") + conv = relay.cast( + relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"), + "float16", + ) + + bn_shape = [5] + gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16") + beta = relay.cast(relay.var("beta", shape=bn_shape), "float16") + moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16") + moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16") + bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var) + + expected_mod = tvm.IRModule.from_expr(bn[0]) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_do_not_convert_softmax(): + """Softmax is a red listed operation and therefore should never be fp16.""" + shape = [1, 2, 3] + a = relay.var("a", shape=shape) + b = relay.nn.softmax(a) + mod = tvm.IRModule.from_expr(b) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "a": np.random.uniform(-1, 1, size=shape).astype("float32"), + } + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) + assert tvm.ir.structural_equal(mod, output_mod) + + +def test_green_gray_propagates_simple(): + """Conv is a green listed operation, while addition is gray. + + As Conv outputs fp16 the add should be done in fp16. 
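+
+    Sketch of the expected structure (mirroring ``expected_mod`` below, where
+    ``conv2d_fp32_accum`` is shorthand for the fp16-input conv2d that keeps
+    ``out_dtype="float32"``)::
+
+        conv_expr = relay.cast(conv2d_fp32_accum, "float16")
+        out = conv_expr + conv_expr  # the gray-listed add runs in fp16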
+ """ + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + conv = conv + conv + mod = tvm.IRModule.from_expr(conv) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + conv_expr = relay.cast( + relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ), + "float16", + ) + expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr) + expected_mod = tvm.relay.transform.InferType()(expected_mod) + + assert not tvm.ir.structural_equal(fp16_mod, mod) + assert tvm.ir.structural_equal(fp16_mod, expected_mod) + + +def test_green_red_not_use_extraneous_cast(): + """Conv. is a green listed operation, while softmax is red. + + Conv. also by default accumulates to fp32 but outputs fp16. + + We want to avoid a situation where we have extraneous casts. + E.g. because softmax wants to operate on FP32 we might have + + conv (FP32) -> cast (FP16) -> cast (FP32) -> softmax (FP32) + + To get around this internally when we cast in the pass we cache + the output nodes and the reverse of the cast back to the original + node. For example casting the `conv (FP32)` to FP16 would produce: + + `conv (FP32) -> cast (FP16)` + + As the outputs. Now anytime we try to cast the `conv (FP32)` node + to FP16 it would return the cached result instead of a new cast node: + + `conv (FP32) -> cast (FP16)` + + Furthermore, if we try to cast the `cast (FP16)` node back to FP32 it + would just return + + `conv (FP32)`. + + This test makes sure this behavior occurs. 
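+
+    Concretely, the structure asserted below is (sketch, strides/padding omitted)::
+
+        relay.nn.softmax(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                out_dtype="float32",
+            )
+        )
+
+    i.e. no fp32 -> fp16 -> fp32 cast chain is left between conv2d and softmax.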
+ """ + data_shape = (1, 3, 32, 32) + weight_shape = (5, 3, 3, 3) + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32") + result = relay.nn.softmax(conv) + mod = tvm.IRModule.from_expr(result) + + mod_params = { + "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"), + "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"), + } + fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3) + + # Construct expected structure + conv = relay.nn.conv2d( + relay.cast(data, "float16"), + relay.cast(weight, "float16"), + strides=(1, 1), + padding=(1, 1), + out_dtype="float32", + ) + result = relay.nn.softmax(conv) + expected_mod = tvm.IRModule.from_expr(result) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, fp16_mod) + + +def test_red_gray_propagates_simple(): + """Everything after a softmax should be in FP32 (exception green colored ops)""" + shape = [1, 2, 3] + a = relay.var("a", shape=shape) + b = relay.nn.softmax(a) + c = b + b + mod = tvm.IRModule.from_expr(c) + mod = tvm.relay.transform.InferType()(mod) + + mod_params = { + "a": np.random.uniform(-1, 1, size=shape).astype("float32"), + } + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0.0) + + assert tvm.ir.structural_equal(mod, output_mod) + + +def test_let_statement_simple(): + """A 'simple' let statement example. + + Noticeable is the mutation of the bound variable types. + """ + var1 = relay.var("var1", shape=[1, 20]) + var2 = relay.var("var2", shape=[1, 20]) + + data = relay.var("data", shape=[1, 20]) + weight = relay.var("weight", shape=[20, 20]) + + r1 = var1 + var1 + + r2 = var2 + var2 + let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2) + let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2) + + mod = tvm.IRModule.from_expr(let1) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + + # Construct expected structure + var1 = relay.var("var1", shape=[1, 20], dtype="float16") + var2 = relay.var("var2", shape=[1, 20], dtype="float16") + data = relay.cast(relay.var("data", shape=[1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") + r1 = var1 + var1 + r2 = var2 + var2 + let2 = relay.Let( + var2, + relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"), + r2, + ) + let1 = relay.Let( + var1, + relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"), + let2, + ) + expected_mod = tvm.IRModule.from_expr(let1) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_where_simple(): + data = relay.var("data", shape=[1, 20]) + weight = relay.var("weight", shape=[20, 20]) + a = relay.nn.dense(data, weight, units=20) + b = relay.where(data, a, a) + mod = tvm.IRModule.from_expr(b) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), + } + + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + + # Create expected module + data = 
relay.cast(relay.var("data", shape=[1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16") + a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16") + b = relay.where(data, a, a) + expected_mod = tvm.IRModule.from_expr(b) + expected_mod = InferType()(expected_mod) + + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_batch_matmul_simple(): + """Batch matmul is a special case where we try to accumulate to fp16. + + This is because heterogeneous accumulation dtypes do not work + on all platforms at the moment. + """ + data = relay.var("data", shape=[1, 1, 20]) + weight = relay.var("weight", shape=[1, 20, 20]) + a = relay.nn.batch_matmul(data, weight) + mod = tvm.IRModule.from_expr(a) + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"), + "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16") + weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16") + a = relay.nn.batch_matmul(data, weight, out_dtype="float16") + expected_mod = tvm.IRModule.from_expr(a) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 182645117f36..f90c9548ec02 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -401,33 +401,62 @@ def test_rpc_tracker_register(): # test registration tracker = Tracker(port=9000, port_end=10000) device_key = "test_device" - server = rpc.Server( + server1 = rpc.Server( + host="127.0.0.1", + port=9000, + port_end=10000, + key=device_key, + tracker_addr=("127.0.0.1", tracker.port), + ) + server2 = rpc.Server( + host="127.0.0.1", port=9000, port_end=10000, key=device_key, tracker_addr=("127.0.0.1", tracker.port), + custom_addr="test_addr", # a test address that cannot actually be connected to ) time.sleep(1) client = rpc.connect_tracker("127.0.0.1", tracker.port) + def exist_address(summary, key, host, port): + server_info = summary["server_info"] + for device in server_info: + if device["key"] == "server:%s" % key: + addr = device["addr"] + if (host is None or host == addr[0]) and port == addr[1]: + return True + return False + summary = client.summary() - assert summary["queue_info"][device_key]["free"] == 1 + assert summary["queue_info"][device_key]["free"] == 2 + assert exist_address(summary, device_key, "127.0.0.1", server1.port) + assert exist_address(summary, device_key, "test_addr", server2.port) remote = client.request(device_key) summary = client.summary() - assert summary["queue_info"][device_key]["free"] == 0 + assert summary["queue_info"][device_key]["free"] == 1 del remote time.sleep(1) + summary = client.summary() + assert summary["queue_info"][device_key]["free"] == 2 + + server1.terminate() + time.sleep(1) + summary = client.summary() assert summary["queue_info"][device_key]["free"] == 1 + assert not exist_address(summary, device_key, "127.0.0.1", server1.port) + assert exist_address(summary, device_key, "test_addr", server2.port) - server.terminate() + server2.terminate() time.sleep(1) summary = client.summary() assert summary["queue_info"][device_key]["free"] == 0 + 
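# Sketch of the tracker bookkeeping this updated test exercises; the import paths are
# assumptions based on common TVM usage, everything else mirrors the hunk above.
import time
from tvm import rpc
from tvm.rpc.tracker import Tracker

tracker = Tracker(port=9000, port_end=10000)
server = rpc.Server(
    host="127.0.0.1",
    port=9000,
    port_end=10000,
    key="test_device",
    tracker_addr=("127.0.0.1", tracker.port),
)
time.sleep(1)  # give the server a moment to register with the tracker

client = rpc.connect_tracker("127.0.0.1", tracker.port)
summary = client.summary()
# "queue_info" counts free/pending workers per device key; "server_info" lists each
# registered server with a key of the form "server:<device_key>" and an "addr" pair
# of (host or custom_addr, port), which is what exist_address() above inspects.
assert summary["queue_info"]["test_device"]["free"] == 1

server.terminate()
tracker.terminate()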
assert not exist_address(summary, device_key, "test_addr", server2.port) tracker.terminate() diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index dc165331729e..c8cddf8b9598 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -16,6 +16,8 @@ # under the License. import re +import sys + import numpy as np import tvm @@ -357,12 +359,63 @@ def do_compute(A, B, n): tvm.testing.assert_allclose(b.numpy(), [210]) +@tvm.testing.parametrize_targets("vulkan") +def test_vulkan_local_threadidx(target, dev): + # To access the thread index, the vulkan runtime accesses a global + # array of thread indices, storing the result in a local variable. + # In CUDA, these are the built-in threadIdx.x variables, which are + # globally accessible. In vulkan, these local variables must be + # defined inside a function, but are hoisted up to the function + # header to mimic the global CUDA semantics. Before this + # hoisting, this test could trigger spvValidate errors for + # potentially undeclared variables. + + def do_compute(A, B, n): + ib = tvm.tir.ir_builder.create() + A = ib.buffer_ptr(A) + B = ib.buffer_ptr(B) + + # One single declaration of te.thread_axis. + tx = te.thread_axis("threadIdx.x") + + with ib.for_range(0, 1): + # Used inside a for-loop scope, defines local thread_id + # variable. + ib.scope_attr(tx, "thread_extent", 16) + B[tx + 0] = A[tx + 0] + + with ib.for_range(0, 1): + # Used in next scope. If local variable defined at point + # of use instead of function header, will fail spvValidate + # for access of out-of-scope local variable. + ib.scope_attr(tx, "thread_extent", 16) + B[tx + 16] = A[tx + 16] + + return ib.get() + + n = te.var("n") + A = te.placeholder((n,), name="A", dtype="int32") + B = te.placeholder((n,), name="B", dtype="int32") + + B = te.extern( + A.shape, + [A], + lambda ins, outs: do_compute(ins[0], outs[0], n), + dtype="int32", + ) + s = te.create_schedule(B.op) + + # Expected failure occurs at build step. 
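+ # Before the hoisting described above, this tvm.build call is where spvValidate + # would flag the use of a thread-index variable declared in an earlier scope; + # with the hoisting in place, both the build and the execution below succeed.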
+ func = tvm.build(s, [A, B], target) + + n = 32 + a_np = np.arange(n).astype(dtype=A.dtype) + b_np = np.zeros((n,), dtype="int32") + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + func(a, b) + tvm.testing.assert_allclose(b.numpy(), a_np) + + if __name__ == "__main__": - test_vector_comparison() - test_vulkan_copy() - test_vulkan_vectorize_add() - test_vulkan_stress() - test_vulkan_constant_passing() - test_vulkan_bool_load() - test_vulkan_pushconstants() - test_vulkan_unique() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 98a9edc7a517..3ed275800bd2 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -299,5 +299,16 @@ def test_check_and_update_host_consist_3(): assert target.host == host +def test_target_attr_bool_value(): + target0 = Target("llvm --link-params=True") + assert target0.attrs["link-params"] == 1 + target1 = Target("llvm --link-params=true") + assert target1.attrs["link-params"] == 1 + target2 = Target("llvm --link-params=False") + assert target2.attrs["link-params"] == 0 + target3 = Target("llvm --link-params=false") + assert target3.attrs["link-params"] == 0 + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index c5163a8457af..1e3c8061e029 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -20,7 +20,7 @@ def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(sche, params, "main", None)["main"] func = passfunc()(tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt @@ -42,7 +42,7 @@ def get_promoted(op): lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body def test_promoted(op): @@ -111,7 +111,7 @@ def get_target(): ), ) s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] return func.body tvm.ir.assert_structural_equal(get_eliminated(), get_target()) @@ -151,7 +151,7 @@ def check(fcompute_before, fcompute_after): b = te.placeholder((100,), dtype="uint16", name="B") c = te.compute((100,), fcompute_after(a, b), name="C") s = te.create_schedule(c.op) - func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] tvm.ir.assert_structural_equal(stmt, func.body) def orig1(a, b): diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 7d02e4f12c1d..252a187dbdc5 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -522,7 +522,7 @@ def test_hoisting_block_scope_1(): s[B.op].bind(xi, te.thread_axis("threadIdx.y")) s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) s[BF].compute_at(s[B], 
s[B].op.reduce_axis[0]) - func = tvm.driver.build_module.form_irmodule(s, [A, B], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) @@ -622,7 +622,7 @@ def test_hoisting_block_scope_4(): s[C].pragma(xo2, "parallel_stride_pattern") s[C].pragma(xo2, "parallel_barrier_when_finish") s[C].vectorize(xi) - func = tvm.driver.build_module.form_irmodule(s, [A, B, C], "main", None)["main"] + func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"] stmt = func.body new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py new file mode 100644 index 000000000000..016e7f7427f6 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
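# A small self-contained sketch (the toy compute and names are illustrative) of the
# schedule_to_module pattern that the bf16 and hoist-if test updates above switch to:
import tvm
from tvm import te

A = te.placeholder((100,), name="A")
B = te.compute((100,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
# Lower the TE schedule to an IRModule and take its "main" PrimFunc, mirroring
# lower_stmt()/get_promoted()/check() in the hunks above.
main_func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"]
print(main_func.body)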
+ +import tvm +from tvm.script import ty +from tvm import te, tir +import numpy as np +import tvm.testing + + +@tvm.script.tir +def get_valid_counts( + data: ty.handle, + valid_count: ty.handle, + out: ty.handle, + out_indices: ty.handle, + score_threshold: ty.float32, + id_index: ty.int32, + score_index: ty.int32, +) -> None: + + data_buf = tir.match_buffer(data, (1, 2500, 6), "float32") + valid_count_buf = tir.match_buffer(valid_count, (1,), "int32") + out_buf = tir.match_buffer(out, (1, 2500, 6), "float32") + out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32") + + with tir.block([1], "init") as [vi]: + valid_count_buf[vi] = tir.int32(0) + with tir.block([2500], "update") as [vj]: + tir.reads([data_buf[vi, vj, 6]]) + tir.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) + if (data_buf[vi, vj, score_index] > score_threshold) and ( + (id_index < 0) or (data_buf[vi, vj, id_index] >= tir.float32(0)) + ): + for k in tir.serial(0, 6): + out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] + out_indices_buf[vi, valid_count_buf[vi]] = vj + valid_count_buf[vi] = valid_count_buf[vi] + 1 + if vj >= valid_count_buf[vi]: + for k in tir.serial(0, 6): + out_buf[vi, vj, k] = tir.float32(-1) + out_indices_buf[vi, vj] = tir.int32(-1) + + +def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index): + dtype = "float32" + ctx = tvm.cpu() + batch_size, num_anchor, elem_length = dshape + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) + np_out1 = np.zeros(shape=(batch_size,), dtype="int32") + np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor), dtype="int32") + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor): + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): + for k in range(elem_length): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + np_out3[i, inter_idx] = j + inter_idx += 1 + if j >= np_out1[i]: + for k in range(elem_length): + np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 + + in_data = tvm.nd.array(np_data, ctx) + score_threshold_data = tvm.nd.array(np.array([score_threshold], dtype=dtype), ctx) + out1 = tvm.nd.array(np_out1, ctx) + out2 = tvm.nd.array(np_out2, ctx) + out3 = tvm.nd.array(np_out3, ctx) + f(in_data, out1, out2, out3, score_threshold, id_index, score_index) + tvm.testing.assert_allclose(out1.numpy(), np_out1, rtol=1e-5) + tvm.testing.assert_allclose(out2.numpy(), np_out2, rtol=1e-5) + tvm.testing.assert_allclose(out3.numpy(), np_out3, rtol=1e-5) + print("test get_valid_counts end") + + +def test_get_valid_counts_script_func(): + device = "llvm" + # check lowering + print(tvm.script.asscript(get_valid_counts)) + mod = tvm.script.create_module({"get_valid_counts": get_valid_counts}) + print(tvm.script.asscript(mod)) + # check building + f = tvm.build(mod["get_valid_counts"], target=device) + _check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1) + + +if __name__ == "__main__": + test_get_valid_counts_script_func() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index e84902f0540e..164949552859 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2888,6 +2888,64 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 +@tvm.script.tir +def rank0(a: 
ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + B = tir.alloc_buffer((), "float32") + A[()] = 2 + B[()] = A[()] + + +def test_rank0_buffers(): + func = rank0 + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@tvm.script.tir +def rank0_block(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + B = tir.alloc_buffer((), "float32") + tir.store(B.data, 0, tir.load("float32", A.data, 0)) + + with tir.block([], "update") as []: + tir.reads([A[()]]) + tir.writes([B[()]]) + for i in range(0, 1): + B[()] = A[()] + + +def test_rank0_blocks(): + func = rank0_block + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@tvm.script.tir +def select(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + A[()] = tir.Select(True, 1, 2) + + +def test_select(): + func = select + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@tvm.script.tir +def minmax(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + A[()] = tir.min(1, 2) + A[()] = tir.max(1, 2) + + +def test_minmax(): + func = minmax + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index b3a084aef371..cae28467830f 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -34,3 +34,4 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake +echo set\(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "/opt/acl"\) >> config.cmake diff --git a/tutorials/get_started/install.py b/tutorials/get_started/install.py index efc951a52709..e022e4b1ae2e 100644 --- a/tutorials/get_started/install.py +++ b/tutorials/get_started/install.py @@ -32,7 +32,7 @@ # ---------------------- # Installing from source is the recommended method for installing TVM. It will # allow you to enable specific features such as GPU support, microcontroller -# support (uTVM), and a debugging runtime, and other features. You will also +# support (microTVM), and a debugging runtime, and other features. You will also # want to install from source if you want to actively contribute to the TVM # project. The full instructions are on the `Install TVM From Source # `_ page.
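# A recap sketch of the roundtrip idiom shared by the TVMScript tests added above
# (rank0, rank0_block, select, minmax); `tir_func` stands for any of those functions.
import tvm

def roundtrip_check(tir_func):
    # Print the function back as TVM script, re-parse it, and require that the
    # re-parsed function is structurally identical to the original.
    script = tvm.script.asscript(tir_func, True)
    reparsed = tvm.script.from_source(script)
    tvm.ir.assert_structural_equal(tir_func, reparsed)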