From 3e7916d30ad2b453dc6879275551b5c6a36f14cc Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Tue, 31 May 2022 11:11:14 -0700 Subject: [PATCH 001/181] [ci][docker] Prune all non-relevant images (#11497) * [skip ci][ci][docker] Prune all non-relevant images (#11491) Before this would leave around any image that could be used in CI. This PR changes it so that the `docker rmi` knows exactly which image is being used in CI so all others (even those that are being used in the same build but not currently on that node) are deleted This also adds some more logging so we can see what's going on and should help keep disk usage down. Co-authored-by: driazati * [skip ci] Revert "[skip ci][ci][docker] Prune all non-relevant images (#11491)" (#11496) * [ci][docker] Prune all non-relevant images (this is a re-do of #11491) Before this would leave around any image that could be used in CI. This PR changes it so that the `docker rmi` knows exactly which image is being used in CI so all others (even those that are being used in the same build but not currently on that node) are deleted This also adds some more logging so we can see what's going on and should help keep disk usage down. Skipped CI since this runs during lint. Co-authored-by: driazati --- Jenkinsfile | 88 +++++++++++++++++++++++++++++++---- jenkins/Build.groovy.j2 | 7 +++ jenkins/DockerBuild.groovy.j2 | 8 ++++ jenkins/Lint.groovy.j2 | 1 + jenkins/Prepare.groovy.j2 | 23 +++++++-- jenkins/Test.groovy.j2 | 15 +++++- jenkins/macros.j2 | 9 ++-- 7 files changed, 134 insertions(+), 17 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index d239d362f9ae3..44389ba767dc7 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,7 +45,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-05-26T15:43:31.409794 +// Generated at 2022-05-27T14:45:11.226042 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> @@ -108,11 +108,7 @@ def per_exec_ws(folder) { def init_git() { checkout scm - // Clear out all Docker images that aren't going to be used - sh( - script: "docker image ls --all --format '{{.Repository}}:{{.Tag}} {{.ID}}' | { grep -vE '${ci_arm}|${ci_cpu}|${ci_gpu}|${ci_hexagon}|${ci_i386}|${ci_lint}|${ci_qemu}|${ci_wasm}' || test \$? = 1; } | { xargs docker rmi || test \$? = 123; }", - label: 'Clean old Docker images', - ) + // Add more info about job node sh ( script: './tests/scripts/task_show_node_info.sh', @@ -160,6 +156,23 @@ def init_git() { ) } +def docker_init(image) { + // Clear out all Docker images that aren't going to be used + sh( + script: """ + set -eux + docker image ls --all + IMAGES=\$(docker image ls --all --format '{{.Repository}}:{{.Tag}} {{.ID}}') + + echo -e "Found images:\\n\$IMAGES" + echo "\$IMAGES" | { grep -vE '${image}' || test \$? = 1; } | { xargs docker rmi || test \$? 
= 123; } + + docker image ls --all + """, + label: 'Clean old Docker images', + ) +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', @@ -321,6 +334,7 @@ def build_docker_images() { parallel 'ci-lint': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_lint') } @@ -328,6 +342,7 @@ def build_docker_images() { }, 'ci-cpu': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_cpu') } @@ -335,6 +350,7 @@ def build_docker_images() { }, 'ci-gpu': { node('GPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_gpu') } @@ -342,6 +358,7 @@ def build_docker_images() { }, 'ci-qemu': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_qemu') } @@ -349,6 +366,7 @@ def build_docker_images() { }, 'ci-i386': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_i386') } @@ -356,6 +374,7 @@ def build_docker_images() { }, 'ci-arm': { node('ARM') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_arm') } @@ -363,6 +382,7 @@ def build_docker_images() { }, 'ci-wasm': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_wasm') } @@ -370,6 +390,7 @@ def build_docker_images() { }, 'ci-hexagon': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_hexagon') } @@ -424,6 +445,7 @@ def lint() { 'Lint 1 of 2': { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/lint") { + docker_init(ci_lint) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -441,6 +463,7 @@ def lint() { 'Lint 2 of 2': { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/lint") { + docker_init(ci_lint) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -518,6 +541,7 @@ stage('Build') { if (!skip_ci) { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-gpu") { + docker_init(ci_gpu) init_git() sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu} --no-gpu", 'build', '-j2') @@ -564,6 +588,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-cpu") { + docker_init(ci_cpu) init_git() sh ( script: "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build", @@ -603,6 +628,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-wasm") { + docker_init(ci_wasm) init_git() sh ( script: "${docker_run} ${ci_wasm} ./tests/scripts/task_config_build_wasm.sh build", @@ -627,6 +653,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-i386") { + docker_init(ci_i386) init_git() sh ( script: "${docker_run} ${ci_i386} ./tests/scripts/task_config_build_i386.sh build", @@ -660,6 +687,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-arm") { + docker_init(ci_arm) init_git() sh ( script: "${docker_run} ${ci_arm} ./tests/scripts/task_config_build_arm.sh build", @@ -691,6 +719,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) 
{ node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-qemu") { + docker_init(ci_qemu) init_git() sh ( script: "${docker_run} ${ci_qemu} ./tests/scripts/task_config_build_qemu.sh build", @@ -721,6 +750,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-hexagon") { + docker_init(ci_hexagon) init_git() sh ( script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_config_build_hexagon.sh build", @@ -765,6 +795,7 @@ def shard_run_unittest_GPU_1_of_3() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -830,6 +861,7 @@ def shard_run_unittest_GPU_2_of_3() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -881,6 +913,7 @@ def shard_run_unittest_GPU_3_of_3() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -929,6 +962,7 @@ def shard_run_integration_CPU_1_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -974,6 +1008,7 @@ def shard_run_integration_CPU_2_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1019,6 +1054,7 @@ def shard_run_integration_CPU_3_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1064,6 +1100,7 @@ def shard_run_integration_CPU_4_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1109,6 +1146,7 @@ def shard_run_integration_CPU_5_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1154,6 +1192,7 @@ def shard_run_integration_CPU_6_of_6() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") { try { + docker_init(ci_cpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1200,6 +1239,7 @@ def shard_run_python_i386_1_of_5() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { + docker_init(ci_i386) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1246,6 +1286,7 @@ def shard_run_python_i386_2_of_5() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { + docker_init(ci_i386) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1291,6 +1332,7 @@ def shard_run_python_i386_3_of_5() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { + docker_init(ci_i386) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1336,6 +1378,7 @@ def shard_run_python_i386_4_of_5() { node('CPU-SMALL') { 
ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { + docker_init(ci_i386) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1381,6 +1424,7 @@ def shard_run_python_i386_5_of_5() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") { try { + docker_init(ci_i386) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1427,6 +1471,7 @@ def shard_run_test_Hexagon_1_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1471,6 +1516,7 @@ def shard_run_test_Hexagon_2_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1514,6 +1560,7 @@ def shard_run_test_Hexagon_3_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1557,6 +1604,7 @@ def shard_run_test_Hexagon_4_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1600,6 +1648,7 @@ def shard_run_test_Hexagon_5_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1643,6 +1692,7 @@ def shard_run_test_Hexagon_6_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1686,6 +1736,7 @@ def shard_run_test_Hexagon_7_of_7() { node('CPU-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { try { + docker_init(ci_hexagon) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1730,6 +1781,7 @@ def shard_run_integration_aarch64_1_of_4() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1774,6 +1826,7 @@ def shard_run_integration_aarch64_2_of_4() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1818,6 +1871,7 @@ def shard_run_integration_aarch64_3_of_4() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1862,6 +1916,7 @@ def shard_run_integration_aarch64_4_of_4() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1907,6 +1962,7 @@ def shard_run_topi_GPU_1_of_4() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1950,6 +2006,7 @@ def shard_run_topi_GPU_2_of_4() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -1993,6 +2050,7 @@ def 
shard_run_topi_GPU_3_of_4() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2036,6 +2094,7 @@ def shard_run_topi_GPU_4_of_4() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2080,6 +2139,7 @@ def shard_run_frontend_GPU_1_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2123,6 +2183,7 @@ def shard_run_frontend_GPU_2_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2166,6 +2227,7 @@ def shard_run_frontend_GPU_3_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2209,6 +2271,7 @@ def shard_run_frontend_GPU_4_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2252,6 +2315,7 @@ def shard_run_frontend_GPU_5_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2295,6 +2359,7 @@ def shard_run_frontend_GPU_6_of_6() { node('GPU') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { try { + docker_init(ci_gpu) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2339,6 +2404,7 @@ def shard_run_topi_aarch64_1_of_2() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2387,6 +2453,7 @@ def shard_run_topi_aarch64_2_of_2() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2436,6 +2503,7 @@ def shard_run_frontend_aarch64_1_of_2() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2479,6 +2547,7 @@ def shard_run_frontend_aarch64_2_of_2() { node('ARM-SMALL') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-arm") { try { + docker_init(ci_arm) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -2648,6 +2717,7 @@ stage('Test') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-cpu") { timeout(time: max_time, unit: 'MINUTES') { try { + docker_init(ci_cpu) init_git() withEnv(['PLATFORM=cpu'], { sh( @@ -2692,6 +2762,7 @@ stage('Test') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-qemu") { timeout(time: max_time, unit: 'MINUTES') { try { + docker_init(ci_qemu) init_git() withEnv(['PLATFORM=qemu'], { sh( @@ -2736,6 +2807,7 @@ stage('Test') { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-cpu") { timeout(time: max_time, unit: 'MINUTES') { try { + docker_init(ci_cpu) init_git() withEnv(['PLATFORM=cpu'], { sh( @@ -2773,6 +2845,7 @@ stage('Test') { if (!skip_ci) { node('GPU') { 
ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/docs-python-gpu") { + docker_init(ci_gpu) init_git() sh( script: """ @@ -2814,8 +2887,7 @@ stage('Test') { }, ) } -} -/* +}/* stage('Build packages') { parallel 'conda CPU': { node('CPU') { diff --git a/jenkins/Build.groovy.j2 b/jenkins/Build.groovy.j2 index 4b0b4ae2e2c80..62ccc94916048 100644 --- a/jenkins/Build.groovy.j2 +++ b/jenkins/Build.groovy.j2 @@ -62,6 +62,7 @@ stage('Build') { if (!skip_ci) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-gpu') }}) { + docker_init(ci_gpu) init_git() sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" make("${ci_gpu} --no-gpu", 'build', '-j2') @@ -79,6 +80,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-cpu') }}) { + docker_init(ci_cpu) init_git() sh ( script: "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build", @@ -102,6 +104,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-wasm') }}) { + docker_init(ci_wasm) init_git() sh ( script: "${docker_run} ${ci_wasm} ./tests/scripts/task_config_build_wasm.sh build", @@ -126,6 +129,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-i386') }}) { + docker_init(ci_i386) init_git() sh ( script: "${docker_run} ${ci_i386} ./tests/scripts/task_config_build_i386.sh build", @@ -143,6 +147,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('ARM-SMALL') { ws({{ m.per_exec_ws('tvm/build-arm') }}) { + docker_init(ci_arm) init_git() sh ( script: "${docker_run} ${ci_arm} ./tests/scripts/task_config_build_arm.sh build", @@ -160,6 +165,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-qemu') }}) { + docker_init(ci_qemu) init_git() sh ( script: "${docker_run} ${ci_qemu} ./tests/scripts/task_config_build_qemu.sh build", @@ -177,6 +183,7 @@ stage('Build') { if (!skip_ci && is_docs_only_build != 1) { node('CPU-SMALL') { ws({{ m.per_exec_ws('tvm/build-hexagon') }}) { + docker_init(ci_hexagon) init_git() sh ( script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_config_build_hexagon.sh build", diff --git a/jenkins/DockerBuild.groovy.j2 b/jenkins/DockerBuild.groovy.j2 index 84bb8e3e376d1..e9d80801a9d9c 100644 --- a/jenkins/DockerBuild.groovy.j2 +++ b/jenkins/DockerBuild.groovy.j2 @@ -59,6 +59,7 @@ def build_docker_images() { parallel 'ci-lint': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_lint') } @@ -66,6 +67,7 @@ def build_docker_images() { }, 'ci-cpu': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_cpu') } @@ -73,6 +75,7 @@ def build_docker_images() { }, 'ci-gpu': { node('GPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_gpu') } @@ -80,6 +83,7 @@ def build_docker_images() { }, 'ci-qemu': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_qemu') } @@ -87,6 +91,7 @@ def build_docker_images() { }, 'ci-i386': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_i386') } @@ -94,6 +99,7 @@ def build_docker_images() { }, 'ci-arm': { node('ARM') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_arm') } @@ -101,6 +107,7 @@ def 
build_docker_images() { }, 'ci-wasm': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_wasm') } @@ -108,6 +115,7 @@ def build_docker_images() { }, 'ci-hexagon': { node('CPU') { timeout(time: max_time, unit: 'MINUTES') { + docker_init('none') init_git() build_image('ci_hexagon') } diff --git a/jenkins/Lint.groovy.j2 b/jenkins/Lint.groovy.j2 index 61c13cd407d02..40dad3aef7be3 100644 --- a/jenkins/Lint.groovy.j2 +++ b/jenkins/Lint.groovy.j2 @@ -6,6 +6,7 @@ def lint() { num_shards=2, node='CPU-SMALL', ws='tvm/lint', + docker_image='ci_lint', ) %} sh ( diff --git a/jenkins/Prepare.groovy.j2 b/jenkins/Prepare.groovy.j2 index b4db7de63bd15..2900775f49452 100644 --- a/jenkins/Prepare.groovy.j2 +++ b/jenkins/Prepare.groovy.j2 @@ -6,11 +6,7 @@ def per_exec_ws(folder) { def init_git() { checkout scm - // Clear out all Docker images that aren't going to be used - sh( - script: "docker image ls --all --format {% raw %}'{{.Repository}}:{{.Tag}} {{.ID}}'{% endraw %} | { grep -vE '{% for image in images %}{% raw %}${{% endraw %}{{ image.name }}{% raw %}}{% endraw %}{% if not loop.last %}|{% endif %}{% endfor %}' || test \$? = 1; } | { xargs docker rmi || test \$? = 123; }", - label: 'Clean old Docker images', - ) + // Add more info about job node sh ( script: './tests/scripts/task_show_node_info.sh', @@ -58,6 +54,23 @@ def init_git() { ) } +def docker_init(image) { + // Clear out all Docker images that aren't going to be used + sh( + script: """ + set -eux + docker image ls --all + IMAGES=\$(docker image ls --all --format {% raw %}'{{.Repository}}:{{.Tag}} {{.ID}}'{% endraw %}) + + echo -e "Found images:\\n\$IMAGES" + echo "\$IMAGES" | { grep -vE '${image}' || test \$? = 1; } | { xargs docker rmi || test \$? 
= 123; } + + docker image ls --all + """, + label: 'Clean old Docker images', + ) +} + def should_skip_slow_tests(pr_number) { withCredentials([string( credentialsId: 'tvm-bot-jenkins-reader', diff --git a/jenkins/Test.groovy.j2 b/jenkins/Test.groovy.j2 index a08c50905a056..9f949ae717c2a 100644 --- a/jenkins/Test.groovy.j2 +++ b/jenkins/Test.groovy.j2 @@ -10,6 +10,7 @@ node="GPU", ws="tvm/ut-python-gpu", platform="gpu", + docker_image="ci_gpu", test_method_names=test_method_names, ) %} {% if shard_index == 1 %} @@ -44,6 +45,7 @@ num_shards=6, ws="tvm/integration-python-cpu", platform="cpu", + docker_image="ci_cpu", test_method_names=test_method_names, ) %} {{ m.download_artifacts(tag='cpu', filenames=tvm_multilib_tsim) }} @@ -59,6 +61,7 @@ num_shards=5, ws="tvm/integration-python-i386", platform="i386", + docker_image="ci_i386", test_method_names=test_method_names, ) %} {{ m.download_artifacts(tag='i386', filenames=tvm_multilib) }} @@ -78,6 +81,7 @@ node="CPU-SMALL", ws="tvm/test-hexagon", platform="hexagon", + docker_image="ci_hexagon", test_method_names=test_method_names, num_shards=7, ) %} @@ -98,6 +102,7 @@ node="ARM-SMALL", ws="tvm/ut-python-arm", platform="arm", + docker_image="ci_arm", test_method_names=test_method_names, ) %} {{ m.download_artifacts(tag='arm', filenames=tvm_multilib) }} @@ -114,6 +119,7 @@ num_shards=4, ws="tvm/topi-python-gpu", platform="gpu", + docker_image="ci_gpu", test_method_names=test_method_names, ) %} {{ m.download_artifacts(tag='gpu', filenames=tvm_multilib) }} @@ -129,6 +135,7 @@ num_shards=6, ws="tvm/frontend-python-gpu", platform="gpu", + docker_image="ci_gpu", test_method_names=test_method_names, ) %} {{ m.download_artifacts(tag='gpu', filenames=tvm_multilib) }} @@ -143,6 +150,7 @@ node="ARM-SMALL", ws="tvm/ut-python-arm", platform="arm", + docker_image="ci_arm", num_shards=2, test_method_names=test_method_names, ) %} @@ -163,6 +171,7 @@ node="ARM-SMALL", ws="tvm/frontend-python-arm", platform="arm", + docker_image="ci_arm", num_shards=2, test_method_names=test_method_names, ) %} @@ -191,6 +200,7 @@ stage('Test') { node="CPU-SMALL", ws="tvm/ut-python-cpu", platform="cpu", + docker_image="ci_cpu", ) %} {{ m.download_artifacts(tag='cpu', filenames=tvm_multilib_tsim) }} ci_setup(ci_cpu) @@ -207,6 +217,7 @@ stage('Test') { node="CPU-SMALL", ws="tvm/test-qemu", platform="qemu", + docker_image="ci_qemu", ) %} {{ m.download_artifacts(tag='qemu', filenames=tvm_lib, folders=microtvm_template_projects) }} add_microtvm_permissions() @@ -226,6 +237,7 @@ stage('Test') { node="CPU-SMALL", ws="tvm/frontend-python-cpu", platform="cpu", + docker_image="ci_cpu", ) %} {{ m.download_artifacts(tag='cpu', filenames=tvm_multilib) }} ci_setup(ci_cpu) @@ -238,6 +250,7 @@ stage('Test') { if (!skip_ci) { node('GPU') { ws({{ m.per_exec_ws('tvm/docs-python-gpu') }}) { + docker_init(ci_gpu) init_git() {{ m.download_artifacts(tag='gpu', filenames=tvm_multilib, folders=microtvm_template_projects) }} add_microtvm_permissions() @@ -256,4 +269,4 @@ stage('Test') { }, ) } -} +} \ No newline at end of file diff --git a/jenkins/macros.j2 b/jenkins/macros.j2 index 1c649e31fabfd..5a641b73fea84 100644 --- a/jenkins/macros.j2 +++ b/jenkins/macros.j2 @@ -19,7 +19,7 @@ "workspace/exec_${env.EXECUTOR_NUMBER}/{{ folder }}" {%- endmacro -%} -{% macro sharded_test_step(name, num_shards, node, ws, platform, test_method_names) %} +{% macro sharded_test_step(name, num_shards, node, ws, docker_image, platform, test_method_names) %} {% for shard_index in range(1, num_shards + 1) %} {% set method_name = 
"shard_run_" + name.replace(":", "").replace(" ", "-").replace("-", "_") + "_" + shard_index|string + "_of_" + num_shards|string %} @@ -28,6 +28,7 @@ def {{ method_name }}() { node('{{ node }}') { ws({{ per_exec_ws(ws) }}) { try { + docker_init({{ docker_image }}) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -51,11 +52,12 @@ def {{ method_name }}() { {% endfor %} {% endmacro %} -{% macro sharded_lint_step(name, num_shards, node, ws) %} +{% macro sharded_lint_step(name, num_shards, docker_image, node, ws) %} {% for shard_index in range(1, num_shards + 1) %} '{{ name }} {{ shard_index }} of {{ num_shards }}': { node('{{ node }}') { ws({{ per_exec_ws(ws) }}) { + docker_init({{ docker_image }}) init_git() timeout(time: max_time, unit: 'MINUTES') { withEnv([ @@ -71,13 +73,14 @@ def {{ method_name }}() { {% endmacro %} -{% macro test_step(name, node, ws, platform) %} +{% macro test_step(name, node, ws, docker_image, platform) %} '{{ name }}': { if (!skip_ci && is_docs_only_build != 1) { node('{{ node }}') { ws({{ per_exec_ws(ws) }}) { timeout(time: max_time, unit: 'MINUTES') { try { + docker_init({{ docker_image }}) init_git() withEnv(['PLATFORM={{ platform }}'], { {{ caller() | indent(width=12) | trim }} From c1b22eefb5dc5c00d945a4cae6c91ce078afcc7d Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 1 Jun 2022 02:50:00 +0800 Subject: [PATCH 002/181] [Arith] Merge surjective/non-surjective iter mapping detections (#11287) * simplify (x * 96) % 64 to (x * 32) % 64 * adapt merge mulmod opt for OffsetOf computation * merge DetectIterMap and DetectIterMapPadded * adjust related interfaces for IterMapLevel * - check incompatible left paddings - determine case like x % 16, x in [0, 5) to be non-surjective, since usages may treat the region extent as 16 by mistake. - skip second round of rewrite when there is no padding - fix some typo in comments * rebase upstream --- include/tvm/arith/iter_affine_map.h | 114 ++- python/tvm/arith/iter_affine_map.py | 53 +- src/arith/int_set.cc | 5 +- src/arith/iter_affine_map.cc | 490 +++++++------ src/arith/pattern_match.h | 2 + src/arith/rewrite_simplify.cc | 72 +- src/arith/rewrite_simplify.h | 2 + src/tir/ir/buffer.cc | 17 +- src/tir/ir/index_map.cc | 23 +- src/tir/schedule/analysis/analysis.cc | 8 +- src/tir/schedule/analysis/layout.cc | 11 +- .../schedule/primitive/blockize_tensorize.cc | 7 +- src/tir/schedule/primitive/compute_at.cc | 2 +- src/tir/schedule/primitive/compute_inline.cc | 5 +- .../primitive/layout_transformation.cc | 7 +- .../schedule/primitive/loop_transformation.cc | 2 +- .../unittest/test_arith_iter_affine_map.py | 674 ++++++++++-------- .../unittest/test_arith_rewrite_simplify.py | 14 +- tests/python/unittest/test_tir_buffer.py | 14 +- .../unittest/test_tir_schedule_compute_at.py | 38 + 20 files changed, 871 insertions(+), 689 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 4cf6f086d1ed3..2c0e5e92997af 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); }; +/*! \brief Mapping level for iterators. */ +enum IterMapLevel { + // Require the mapping to be bijective. + Bijective = 0, + // Require the mapping to be surjective. + Surjective = 1, + // No mapping safety check. + NoCheck = 3 +}; + /*! 
- * \brief Detect if indices can be written as - * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] - * - * Here y = some-quasi-affine-iter-map(input_iters) - * and c are symbolic constants. - * - * We also requires that y_i and y_j to be independent for i != j. - * - * For returned value rv, the following is always true: - * - rv[i]->args.size() <=1: only one iterator per element. - * - * \param indices The indices to detect pattern for. - * \param input_iters Map from variable to iterator's range. - * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. - * \param analyzer Analyzer used to get context information. - * \param simplify_trivial_iterators If true, iterators with extent of - * 1 will be replaced with a constant value. - * - * \return The detected pattern if a match exists, - * otherwise return an empty array. + * \brief Result of DetectIterMap. */ -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); +class IterMapResultNode : public Object { + public: + // The detected pattern if a match exists. + Array indices; -/*! \brief A utility struct for return values from DetectPaddedIterMap - */ -struct PaddedIterMapResult { // Any errors that occurred while converting the input indices. If // the array is empty, the conversion was successful. Array errors; - // The detected pattern if a match exists. - Array indices; - - /* \brief Boolean expression indicating if padding was required - * - * `requires_padding` evaluates to true if the returned indices - * contain padding relative to the provided expressions, and false - * otherwise. If `input_iters` contains a variable extent, this - * expression may be in terms of those variables. - */ - PrimExpr requires_padding; - - /* \brief Boolean expression indicating if a specific value w + /*! \brief Boolean expression indicating if a specific value w * * `padding_predicate` evaluates to true for a set of indices that * are outside the bounds of the provided index iterators, but @@ -314,43 +290,57 @@ struct PaddedIterMapResult { * `input_iters`. */ PrimExpr padding_predicate; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("errors", &errors); + v->Visit("indices", &indices); + v->Visit("padding_predicate", &padding_predicate); + } + + static constexpr const char* _type_key = "arith.IterMapResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); +}; + +/*! + * \brief Managed reference to IterMapResultNode. + * \sa IterMapResultNode + */ +class IterMapResult : public ObjectRef { + public: + // constructor + IterMapResult() { data_ = make_object(); } + + /*! \return mutable pointers to the node. */ + IterMapResultNode* operator->() const { return static_cast(get_mutable()); } }; /*! * \brief Detect if indices can be written as * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] * - * Here y = some-quasi-affine-iter-map(input_iters) and c are - * symbolic constants. The y_i iterators may be padded to fit this - * representation. + * Here y = some-quasi-affine-iter-map(input_iters) + * and c are symbolic constants. * * We also requires that y_i and y_j to be independent for i != j. * * For returned value rv, the following is always true: - * - rv.indices[i]->args.size() <=1: only one iterator per element. + * - rv[i]->args.size() <=1: only one iterator per element. 
* * \param indices The indices to detect pattern for. - * * \param input_iters Map from variable to iterator's range. - * * \param predicate The predicate constraints on the input iterators - * - * \param require_bijective A boolean flag that indicates whether the - * mapping should be bijective. If true, no padding may be - * introduced. - * + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. - * * \param simplify_trivial_iterators If true, iterators with extent of * 1 will be replaced with a constant value. * - * \return An instance of PaddedIterMapResult. + * \return The detected iteration result. + * The return object's .indices is empty on failure. */ -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators = true); +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices @@ -358,12 +348,12 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, * \param indices The indices to detect pattern for. * \param input_iters Map from variable to iterator's range. * \param input_pred The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * \param check_level The iter mapping checking level. * * \return The indices after rewrite */ Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective); + const PrimExpr& input_pred, IterMapLevel check_level); /*! * \brief Apply the inverse of the affine transformation to the outputs. @@ -403,7 +393,7 @@ Map InverseAffineIterMap(const Array& iter_map, * \param input_iters Map from variable to iterator's range. * \param sub_iters Iterators of subspace. * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. * * \return The result list has length len(bindings) + 1 @@ -416,7 +406,7 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer); + IterMapLevel check_level, arith::Analyzer* analyzer); /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 2be939a12277c..77d6f418b8537 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
""" Iterator (quasi)affine mapping patterns.""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from tvm.ir import PrimExpr @@ -88,11 +89,35 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) +class IterMapLevel(IntEnum): + """Possible kinds of iter mapping check level.""" + + Bijective = 0 + Surjective = 1 + NoCheck = 3 + + @staticmethod + def from_str(name: str): + """Helper to create level enum from string""" + if name is None: + return IterMapLevel.NoCheck + name = name.lower() + if name == "bijective": + check_level = IterMapLevel.Bijective + elif name == "surjective": + check_level = IterMapLevel.Surjective + elif name == "nocheck": + check_level = IterMapLevel.NoCheck + else: + raise ValueError(f"Unknown check level {name}") + return check_level + + def detect_iter_map( indices, input_iters, predicate=True, - require_bijective=False, + check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, ): """Detect if indices can be written as mapped iters from input iters @@ -108,8 +133,8 @@ def detect_iter_map( predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the mapping should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping simplify_trivial_iterators: bool If true, iterators with extent of 1 will be replaced with a @@ -117,13 +142,17 @@ def detect_iter_map( Returns ------- - results : List[IterSumExpr] + results : IterMapResult The iter map matching result. - Empty array if no match can be found. + The result's .indices is empty array if no match can be found. """ + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + elif check_level is None: + check_level = IterMapLevel.NoCheck return _ffi_api.DetectIterMap( - indices, input_iters, predicate, require_bijective, simplify_trivial_iterators + indices, input_iters, predicate, check_level, simplify_trivial_iterators ) @@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr): return _ffi_api.NormalizeIterMapToExpr(expr) -def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False): +def subspace_divide( + bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective +): """Detect if bindings can be written as [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters) @@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the bindings should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping Returns ------- @@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi len(bindings): the predicate of outer space and inner space Empty array if no match can be found. 
""" - return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level) def inverse_affine_iter_map(iter_map, outputs): diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a3fa879afa270..48fae479b042b 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -867,9 +867,10 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } - iter_sum_exprs = DetectIterMap( + auto res = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer); + /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); + iter_sum_exprs = res->indices; } if (iter_sum_exprs.empty()) { return NullOpt; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 9fad3b2816a12..cce826fedca64 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -178,10 +178,7 @@ class IterMapRewriter : public ExprMutator { explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, bool simplify_trivial_iterators, Array* errors) - : analyzer_(analyzer), - errors_(*errors), - requires_padding_(const_false()), - padding_predicate_(const_false()) { + : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -202,16 +199,17 @@ class IterMapRewriter : public ExprMutator { } PrimExpr padding_predicate() const { return padding_predicate_; } - PrimExpr requires_padding() const { return requires_padding_; } + bool requires_padding() const { return requires_padding_; } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - void UpdatePadding(const PrimExpr& expr) { + IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) { update_iterator_padding_ = true; - DirectMutate(expr); + auto res = Rewrite(expr); update_iterator_padding_ = false; + return res; } IterSumExpr RewriteIterConstraint(const PrimExpr& expr, @@ -222,7 +220,7 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief If require_bijective is true, this function checks two conditions: + * \brief If require bijective mapping, this function checks two conditions: * - C0: Each iter mark should be fully covered by non-overlapping splits. * - C1: All of the input iterators are used. * Example: given x in [0, 8) y in [0, 6) @@ -232,7 +230,7 @@ class IterMapRewriter : public ExprMutator { * contribute two non-overlapping splits that covers x. * - bindings = [x / 4, x % 4] won't pass because y is not used. * - * If require_bijective is false, this function checks one condition: + * If only require surjective mapping, this function checks one condition: * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits. 
* Example: given x in [0, 8) y in [0, 6) * - bindings = [x / 4] will pass because x / 4 can be one split of x @@ -241,7 +239,7 @@ class IterMapRewriter : public ExprMutator { * - bindings = [x / 3] will not pass because x / 3 can not be one split of x * \return whether the bindings are valid */ - bool CheckMapping(const Array& bindings, bool require_bijective) { + bool CheckMapping(const Array& bindings, IterMapLevel check_level) { IterMarkSplitCollector collector; // We can check that for each iter mark: // All the splits that refers to the iter_mark covers its extent. @@ -249,11 +247,11 @@ class IterMapRewriter : public ExprMutator { collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) { return false; } } - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { // all input marks must be visited for (const IterMark& mark : input_marks_) { if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) { @@ -375,13 +373,14 @@ class IterMapRewriter : public ExprMutator { }; struct IterPaddingInfo { - // Used and collected during first pass - std::vector divisors; + // GCD of padding factor collected during first pass + PrimExpr padding_factor{1}; + + PrimExpr left_pad{0}; + PrimExpr right_pad{0}; - // Defined on first encounter in second pass - IterSplitExpr padded; - PrimExpr left_pad; - PrimExpr right_pad; + // Padded form of original iter mark + IterMark padded; }; // temp hash for de-duplication purposes. @@ -427,41 +426,30 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; - // Map from a normal PrimExpr to the padded iterator information for + // Map from an iter mark to the padded iterator information for // it. This is necessary for introducing the same padding in all // usage of an input iterator. (e.g. (i-1) occurring in the // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be // left-padded by 31 for each occurrence.) - std::unordered_map padded_iter_map_; + std::unordered_map padded_iter_map_; + + // Map from padded iter mark to it's origin mark + std::unordered_map padded_origin_map_; - /* If allow_padding_ is true, allow the extents of the IterMap to be + /* If update_iterator_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. * - * For example, if allow_padding_ is true, the expressions i//4 and + * For example, if update_iterator_padding_ is true, the expressions i//4 and * i%4, where i is on the range [0,18), would be represented as * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4). - * This representation would be forbidden if allow_padding_ is false, + * This representation would be forbidden if update_iterator_padding_ is false, * because lower_factor=4 does not evenly divide the original extent of * 18. */ bool update_iterator_padding_{false}; - /* A boolean expression that is true if any padding has been introduced - * by the transformation, and false otherwise. 
- * - * Example: [i//4, i%4], i in range [0,16) - * requires_padding_ will be false - * - * Example: [i//4, i%4], i in range [0,18) - * requires_padding_ will be true - * - * Example: [i//4, i%4], i in range [0,N) - * requires_padding_ will be the expression N%4==0 - */ - PrimExpr requires_padding_; - /* A boolean expression that is true for any padding that has been - * introduced, and false otherwise. If allow_padding_ is false, + * introduced, and false otherwise. If update_iterator_padding_ is false, * padding_predicate_ will always be false. * * Example: [i//4, i%4], i in range [0,16) @@ -475,6 +463,11 @@ class IterMapRewriter : public ExprMutator { */ PrimExpr padding_predicate_; + /* A boolean flag denotes there are padding iterations detected + * in the first round of indices rewriting. + */ + bool requires_padding_{false}; + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly // an extra offset) // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) @@ -538,13 +531,12 @@ class IterMapRewriter : public ExprMutator { * If not, return an empty array. * \param mark The iterator of interest. * \param splits The splits to be verified. - * \param require_bijective A boolean flag that indicates whether the bindings should be - * bijective. + * \param check_level Iteration mapping's check level. * \return The normalized splits. */ Array TryNormalizeSplits(const IterMark& mark, const std::vector& splits, - bool require_bijective) { + IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -559,7 +551,7 @@ class IterMapRewriter : public ExprMutator { } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { return Array(); } // look for the next split skipping this lower factor @@ -578,18 +570,64 @@ class IterMapRewriter : public ExprMutator { expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } + // Extract iteration mark info before padding + auto pad_mark_it = padded_origin_map_.find(mark); + bool has_padding = pad_mark_it != padded_origin_map_.end(); + + bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent); + bool match_iter_divisor = + match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor); + // Case 1. bijective is required. - // We check the extent we calculate is consistent with the extent of the mark - // Case 2. bijective is not required. + // We check the extent we calculate is consistent with the extent of the mark and + // iteration mark's padding is not allowed. + // + // Case 2. bijective is not required and there is no padding. // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - if (require_bijective) { - if (!analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) { - return Array(); + // + // Case 3. bijective is not required and there exists padding. We check either + // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is + // the single split for the iter mark. + // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective + // according to how we pad the original iteration mark. 
+ // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent + // before padding is greater or equal than the extent we calculate. + // For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24. + // + if (check_level == IterMapLevel::Bijective) { + if (has_padding) { + ErrorLogger(this) << "Bijectvie mapping should not take iter paddings"; + return {}; + } else if (!match_full_iter) { + ErrorLogger(this) << "The iterations do not traverse full iter space"; + return {}; } - } else { - if (!CanProveDivisible(mark->extent, expected_lower_factor)) { - return Array(); + } else if (!has_padding) { + if (!match_iter_divisor) { + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; + } + } else if (check_level == IterMapLevel::Surjective) { + PrimExpr extent_before_padding = pad_mark_it->second->extent; + if (match_full_iter) { + if (splits.size() != 1) { + ErrorLogger(this) << "Dependent iterations on padding iter space"; + return Array(); + } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && + !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "Split on padding iteration is not surjective " + << "if the split extent equals to the full iter space extent"; + return Array(); + } + } else if (match_iter_divisor) { + if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "The extent before padding is less than lower factor"; + return Array(); + } + } else { + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; } } return Array(iters.rbegin(), iters.rend()); @@ -1018,39 +1056,23 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { return true; } -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators) { - auto padded_result = DetectPaddedIterMap(indices, input_iters, predicate, require_bijective, - analyzer, simplify_trivial_iterators); - if (padded_result.errors.size()) { - return Array(); - } - if (!analyzer->CanProve(!padded_result.requires_padding)) { - return Array(); - } - return padded_result.indices; -} - -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators) { - PaddedIterMapResult result; +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { + IterMapResult result; // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. if (!IterRangeSanityCheck(input_iters)) { - result.errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); + result->errors.push_back("Invalid iterators. 
Iterators may not be expressions of each other."); return result; } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - result.errors.push_back("Could not parse predicate as constraints on the input iterators."); + result->errors.push_back("Could not parse predicate as constraints on the input iterators."); return result; } // We have to make sure when we visit an iterator, all the constraints related with its successors @@ -1065,58 +1087,65 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators, - &result.errors); + &result->errors); // Step0.0: rewrite constraints in the order from size-small ones to size-big ones for (const IterConstraint& constraint : constraints) { auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound); - if (result.errors.size()) { + if (result->errors.size() > 0) { return result; } } if (!rewriter.CheckConstraints()) { - result.errors.push_back("Invalid constraints."); + result->errors.push_back("Invalid constraints."); return result; } - // Step0.1: Check each index to determine required padding - bool allow_padding = !require_bijective; + // Step0.1: Rewrite indicies and determine required padding, + // if there is no padding, it should be the final result. + Array rewrite_indices; + rewrite_indices.reserve(indices.size()); + bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { for (PrimExpr value : indices) { - rewriter.UpdatePadding(value); + rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value)); + if (result->errors.size() > 0) { + return result; + } } } - // Step0.2: rewrite indices - for (PrimExpr value : indices) { - result.indices.push_back(rewriter.Rewrite(value)); - if (result.errors.size()) { - return result; + // Step0.2: Rewrite indices in the second round. + if (!allow_padding || rewriter.requires_padding()) { + rewrite_indices.clear(); + for (PrimExpr value : indices) { + rewrite_indices.push_back(rewriter.Rewrite(value)); + if (result->errors.size() > 0) { + return result; + } } } - - result.requires_padding = rewriter.requires_padding(); - result.padding_predicate = rewriter.padding_predicate(); + result->padding_predicate = rewriter.padding_predicate(); // Step1: IterIndependenceChecker checks if the iterator are independent. 
- if (!rewriter.CheckMapping(result.indices, require_bijective)) { - if (require_bijective) { - result.errors.push_back("Index mapping does not form a bijective transform."); + if (!rewriter.CheckMapping(rewrite_indices, check_level)) { + if (check_level == IterMapLevel::Bijective) { + result->errors.push_back("Index mapping does not form a bijective transform."); } else { - result.errors.push_back("Mapped indices are not independent."); + result->errors.push_back("Mapped indices are not independent."); } return result; } - + result->indices = rewrite_indices; return result; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") .set_body_typed([](const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective, + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); @@ -1246,15 +1275,17 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o auto split = Downcast(dividend); return IterSumExpr({split}, make_zero(split.dtype())); } else if (dividend->IsInstance()) { - auto opt_fused = TryFuseIters(Downcast(dividend)); + auto sum = Downcast(dividend); + if (sum->args.size() <= 1) { + return sum; + } + auto opt_fused = TryFuseIters(sum); if (!opt_fused) { ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) << ", can't be written as a single fused IterSum"; return IterSumExpr(); } - IterSumExpr fused = opt_fused.value(); - ICHECK_EQ(fused->args.size(), 1U); return fused; } else { @@ -1263,140 +1294,159 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } } +/*! \brief Find approximate least common multiplier. 
*/ +PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { + auto fsplit = [](const PrimExpr& e) -> std::pair { + if (const IntImmNode* imm = e.as()) { + return {1, imm->value}; + } + PVar pv; + PVar pc; + if ((pv * pc).Match(e) || (pc * pv).Match(e)) { + return {pv.Eval(), pc.Eval()->value}; + } else { + return {e, 1}; + } + }; + auto p1 = fsplit(a); + auto p2 = fsplit(b); + auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); + if (analyzer->CanProveEqual(p1.first, p2.first)) { + return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { + return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) { + return p2.first * const_lcm; + } else { + return (p1.first * p2.first) * const_lcm; + } +} + std::pair IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExpr base, PrimExpr divisor) { // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor // If FloorMod: (((source//lower_factor) % extent) + base) % divisor - PrimExpr lookup_key = split; - - auto modified_divisor = [&]() { - if (update_iterator_padding_) { - return divisor; - } - - auto it = padded_iter_map_.find(lookup_key); - if (it == padded_iter_map_.end()) { - return divisor; - } - - const std::vector& divisors = it->second.divisors; - PrimExpr largest_divisor = divisor; - for (const auto& other : divisors) { - if (CanProveDivisible(other, largest_divisor)) { - // New one is bigger, use it - largest_divisor = other; - } else if (CanProveDivisible(largest_divisor, other)) { - // Current is bigger, keep it - } else { - ErrorLogger(this) << "Iterator appears in multiple terms with incompatible divisors " - << tvm::PrettyPrint(largest_divisor) << " and " - << tvm::PrettyPrint(other); - } - } - return largest_divisor; - }(); - - divisor = modified_divisor; - // First, adding any padding that is on the lower side of a - // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0 - // when iter==0. - - PrimExpr left_pad; - - if (is_zero(base)) { - // Padding on the left is unnecessary if base is known to be zero. - left_pad = make_zero(base->dtype); - } else { - left_pad = analyzer_->Simplify(floormod(base, divisor)); - } + // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0 + // when iter == 0. + PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor)); // Next, adding any padding that is on the upper side of a - // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 - // when iter==extent. - + // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, divisor) == 0 + // when iter == extent. PrimExpr right_edge = left_pad + split->extent; PrimExpr right_pad; - if (CanProveDivisible(right_edge, divisor)) { - // Padding on the right is unnecessary if the extent is a multiple of - // the divisor. right_pad = 0; } else { right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); } - if (is_zero(left_pad) && is_zero(right_pad)) { - return {split, left_pad}; - } - + const IterMark& mark = split->source; if (update_iterator_padding_) { // In the first pass, the primary goal is to collect all the divisors - // that may be used for padding. These will impact the divisor used - // to determine padding in the second pass. 
-    IterPaddingInfo& info = padded_iter_map_[lookup_key];
-
-    info.divisors.push_back(divisor);
-
-    PrimExpr padded_extent = left_pad + split->extent + right_pad;
-
-    IterSumExpr as_sum({split}, left_pad);
-    IterMark mark(as_sum, padded_extent);
-    IterSplitExpr new_split(mark);
-
-    return {new_split, left_pad};
+    // that may be used for padding.  These will impact the divisor used
+    // to determine padding in the second pass.  We try to add padding to the
+    // split's source iteration mark so that all splits under the same mark
+    // share the same padded source iteration.
+    auto& info = padded_iter_map_[mark];
+    info.padding_factor =
+        ApproxLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_);
+
+    // If the split itself requires no padding, return directly.
+    if (is_zero(left_pad) && is_zero(right_pad)) {
+      return {split, 0};
+    }
+
+    // Update the padding requirement on the lower side of the source iter mark.
+    // In the second pass, each split checks whether the maximum left padding
+    // on the iter mark is compatible with its own left padding.
+    requires_padding_ = true;
+    PrimExpr mark_left_pad = left_pad * split->lower_factor;
+    info.left_pad = max(info.left_pad, mark_left_pad);
+
+    // Since we only care about the extent in the first pass's result,
+    // we just create a result with a compatible padded extent, ignoring
+    // possible relations between different padded iters.
+    PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad);
+    split.CopyOnWrite()->extent = padded_extent;
+    return {split, left_pad};
   }
 
-  // Any padding that is required during parsing should have been found
-  // during the first pass that determines the GCD.
-  auto it = padded_iter_map_.find(lookup_key);
+  // In the second pass, update the iteration mark to its padded form.
+  auto it = padded_iter_map_.find(mark);
   if (it == padded_iter_map_.end()) {
-    ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset "
-                      << tvm::PrettyPrint(base) << ", which requires padding for divisor "
-                      << tvm::PrettyPrint(divisor) << ".";
-    return {IterSplitExpr(), left_pad};
+    return {split, left_pad};
   }
-
-  IterPaddingInfo& info = it->second;
-
-  if (info.padded.defined()) {
-    // A previous visit already applied padding to this iterator.
-    // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`).
-    ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad));
-    ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad));
-
-    return {info.padded, left_pad};
+  auto& info = it->second;
+  if (is_zero(info.left_pad) && CanProveDivisible(mark->extent, info.padding_factor)) {
+    // the iter mark requires no padding
+    return {split, left_pad};
   }
 
-  // This is the first encounter with the iterator during the second pass.
-  IterSumExpr as_sum({split}, left_pad);
-  IterMark mark(as_sum, left_pad + split->extent + right_pad);
-  info.padded = IterSplitExpr(mark);
-  info.left_pad = left_pad;
-  info.right_pad = right_pad;
-
-  auto left_padding_introduced = (left_pad != 0);
-  // Equivalent to (0 <= split < left_pad), but easier to simplify in
-  // terms of the transformed variables.
-  auto left_padding_predicate =
-      left_padding_introduced && (floordiv(info.padded, divisor) == floordiv(base, divisor) &&
-                                  floormod(info.padded, divisor) < left_pad);
-
-  PrimExpr nparts = ceildiv(right_edge, divisor);
-
-  auto right_padding_introduced = (right_pad != 0);
-
-  // Equivalent to (right_edge <= split < right_edge+right_pad), but
-  // easier to simplify in terms of the transformed variables.
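To make the left/right padding arithmetic above concrete, here is the same computation carried out with plain Python integers for base 71, extent 45 and divisor 32, the constants exercised by test_padding later in this patch (this is a worked example, not the TVM code path itself):

base, extent, divisor = 71, 45, 32

left_pad = base % divisor                      # floormod(base, divisor) = 7
right_edge = left_pad + extent                 # 52
right_pad = (-right_edge) % divisor            # 12
padded_extent = left_pad + extent + right_pad  # 64, a multiple of the divisor
assert padded_extent % divisor == 0

# After padding, floordiv(base + x, divisor) becomes a clean split of the padded
# iterator: it covers padded_extent // divisor = 2 values, offset by
# (base - left_pad) // divisor = 2, matching the (extent=2, base=2) pattern that
# test_padding asserts for floordiv(x + 71, 32).
values = {(base + x) // divisor for x in range(extent)}
assert values == {2, 3}
assert padded_extent // divisor == 2
assert (base - left_pad) // divisor == 2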
- auto right_padding_predicate = right_padding_introduced && - (floordiv(info.padded, divisor) == floordiv(right_edge, divisor) && - floormod(info.padded, divisor) >= floormod(right_edge, divisor)); - - requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); - padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + // check that padding factor is compatible with current split and divisor + ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) + << "The padding factor " << info.padding_factor << " is not divisible by " + << split->lower_factor << " for the split " << split; + ICHECK(CanProveDivisible(info.padding_factor, divisor)) + << "The padding factor " << info.padding_factor << " is not divisible by " << divisor + << " for the split " << split; + + if (!info.padded.defined()) { + // the first time encounter the iter mark to pad, update the padded mark. + PrimExpr mark_left_pad = info.left_pad; + if (CanProveDivisible(mark_left_pad, split->lower_factor)) { + // correct current split's left padding + // (mark_left_pad + iter) // lower_factor % extent => + // (left_pad * lower_factor + mark) // lower_factor % extent => + // (left_pad + mark // lower_factor) % extent => + // left_pad + (mark // lower_factor % extent) => + // left_pad + split + // since the extent covers the full padding range. + left_pad = floordiv(mark_left_pad, split->lower_factor); + } else { + ErrorLogger(this) << "Detect incompatible left padding on " + << tvm::PrettyPrint(NormalizeIterMapToExpr(split)) + << ", the iter mark is left padded with " << mark_left_pad; + return {IterSplitExpr(), PrimExpr()}; + } - return {info.padded, left_pad}; + PrimExpr right_edge = mark->extent + mark_left_pad; + PrimExpr mark_right_pad; + if (CanProveDivisible(right_edge, info.padding_factor)) { + mark_right_pad = 0; + } else { + mark_right_pad = floormod(-right_edge, info.padding_factor); + } + PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad); + info.right_pad = mark_right_pad; + info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); + padded_origin_map_[info.padded] = mark; + + auto left_padding_introduced = (mark_left_pad != 0); + + // Equivalent to (0 <= split < left_pad), but easier to simplify in + // terms of the transformed variables. + auto left_padding_predicate = + left_padding_introduced && + (floordiv(info.padded->source, info.padding_factor) == 0 && + floormod(info.padded->source, info.padding_factor) < mark_left_pad); + auto right_padding_introduced = (mark_right_pad != 0); + + // Equivalent to (right_edge <= split < right_edge + right_pad), but + // easier to simplify in terms of the transformed variables. 
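The effect of the two passes can also be observed from Python. A minimal usage sketch, assuming the updated detect_iter_map binding from this series (string check_level, result object with .indices and .errors) and mirroring the "left padding only, offset non-divisible" case from this patch's test_padding; tvm.ir.Range.from_min_extent stands in for the tests' var_dom helper:

import tvm
from tvm.tir import floordiv, floormod

y = tvm.tir.Var("y", "int32")
dom = {y: tvm.ir.Range.from_min_extent(0, 176)}

# The quotient alone is accepted: the source iter mark can be left-padded.
res = tvm.arith.detect_iter_map([floordiv(y + 80, 32)], dom, check_level="surjective")
assert len(res.indices) == 1, res.errors

# Asking for the quotient and the remainder together fails, because the padded
# positions of the remainder are never produced by the original iterator.
res = tvm.arith.detect_iter_map(
    [floordiv(y + 80, 32), floormod(y + 80, 32)], dom, check_level="surjective"
)
assert len(res.indices) == 0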
+ auto right_padding_predicate = + right_padding_introduced && (floordiv(info.padded->source, info.padding_factor) == + floordiv(right_edge, info.padding_factor) && + floormod(info.padded->source, info.padding_factor) >= + floormod(right_edge, info.padding_factor)); + padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + } + split.CopyOnWrite()->source = info.padded; + split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor); + return {split, left_pad}; } PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { @@ -1462,7 +1512,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), /* scale = */ padded->scale); - auto new_base = floordiv(base - left_pad, rhs); + auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); if (is_zero(new_base)) { return std::move(new_split); } else { @@ -1540,7 +1590,6 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below // where x=floormod(floordiv(iter, lower_factor), extent) + base - auto pair = PadDividendToDivisor(lhs, base, rhs); IterSplitExpr padded = pair.first; if (!padded.defined()) { @@ -1671,19 +1720,20 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective) { + const PrimExpr& input_pred, IterMapLevel check_level) { if (!IterRangeSanityCheck(input_iters)) return indices; Analyzer analyzer; - Array rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer); + Array rewrite = res->indices; + if (rewrite.empty()) { return indices; } - Array res; - res.reserve(rewrite.size()); + Array simplified; + simplified.reserve(rewrite.size()); IterMapToExprNormalizer converter(&analyzer); - for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); - return res; + for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); + return simplified; } /*! 
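IterMapSimplify above is essentially DetectIterMap followed by conversion of each detected IterSumExpr back to a plain PrimExpr. A small Python sketch of that round trip, assuming the updated bindings; it reuses the fused/split pattern that test_normalize_iter_map_to_expr checks later in this patch:

import tvm
from tvm.tir import floordiv, floormod

x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
dom = {
    x: tvm.ir.Range.from_min_extent(0, 10),
    y: tvm.ir.Range.from_min_extent(0, 9),
}

# fuse(yo, xo, yi) where x is split by 5 and y by 3.
fused = floordiv(y, 3) * 6 + floordiv(x, 5) * 3 + floormod(y, 3)

res = tvm.arith.detect_iter_map([fused, floormod(x, 5)], dom, check_level="surjective")
assert len(res.indices) == 2, res.errors

# Detection followed by normalization recovers the simplified index expressions.
simplified = [tvm.arith.normalize_iter_map_to_expr(e) for e in res.indices]
tvm.ir.assert_structural_equal(simplified[0], fused)
tvm.ir.assert_structural_equal(simplified[1], floormod(x, 5))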
@@ -1963,10 +2013,10 @@ class SubspaceDivider { Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer) { + IterMapLevel check_level, arith::Analyzer* analyzer) { if (!IterRangeSanityCheck(input_iters)) return Array>(); - const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer); + const Array& maps = res->indices; if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -1993,10 +2043,10 @@ Array> SubspaceDivide(const Array& bindings, TVM_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective) { + const Array& sub_iters, const PrimExpr& predicate, int check_level) { arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), + &ana); }); class InverseAffineIterMapTransformer { @@ -2128,5 +2178,7 @@ Map InverseAffineIterMap(const Array& iter_map, TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_REGISTER_NODE_TYPE(IterMapResultNode); + } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 7d1f315b3cb3c..6abcc728fc8de 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -203,6 +203,8 @@ class PVar : public Pattern> { return value_; } + T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } + protected: /*! \brief The matched value */ mutable T value_; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index dab78c77a0a1d..f9e38dee48e50 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -776,26 +776,32 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3), c1.Eval()->value > 0 && c3.Eval()->value > 0); - if (floordiv(x * c1, c2).Match(ret)) { + if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) || + floordiv(y + x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - if (c1val > 0 && c2val > 0) { - if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval(); - if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval(); + PrimExpr yval = y.EvalOr(Integer(0)); + if (c2val == 0) return ret; + + // try eliminate residue part + PrimExpr residue = + floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val); + PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 
0 : floordiv(yval, c2val); + auto bound = analyzer_->const_int_bound(residue); + if (bound.defined() && bound->max_value == bound->min_value) { + return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); } - } - if (floordiv(x * c1 + c2, c3).Match(ret)) { - int64_t c1val = c1.Eval()->value; - int64_t c2val = c2.Eval()->value; - int64_t c3val = c3.Eval()->value; - if (c1val > 0 && c3val > 0 && c3val % c1val == 0 && floormod(c2val, c3val) < c1val) { - // assume c3 == a * c1, x == a * y + b, c2 = d * c3 + e then - // (x * c1 + c2) // c3 - // ==> ((a * y + b) * c1 + d * a * c1 + e) // (a * c1) - // ==> y + d + (b * c1 + e) // c3 - // ==> y + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 - // ==> x // (c3 // c1) + (c2 // c3) - return (floordiv(x, floordiv(c3, c1)) + floordiv(c2, c3)).Eval(); + + // try simplify divisor + if (c1val > 0 && c2val > 0 && c2val % c1val == 0 && + CanProveLess(floormod(yval, c2val), c1val)) { + // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then + // (x * c1 + y) // c2 + // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1) + // ==> x' + d + (b * c1 + e) // c2 + // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 + // ==> x // (c2 // c1) + (y // c2) + return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; } } @@ -804,28 +810,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -878,6 +868,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); } return ret; } @@ -930,22 +922,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here - TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, 
c2), floormod(x * floormod(c1, c2), c2), + c2.Eval()->value != 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, c1.Eval()->value > 0 && c2.Eval()->value > 0 && c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), + c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), + c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 258f833a7b21b..202b9209da6df 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -110,6 +110,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } + // Whether x < val + bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } // Whether x == val bool CanProveEqual(const PrimExpr& x, int64_t val) { // TODO(tqchen) refer back to super-analyzer. diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index ccf186634b8af..dffb8b4992851 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -75,13 +75,15 @@ inline std::vector ExprSplitAddition(const PrimExpr& expr) { } // Searches for the following types of expr: -// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki -// mod_l_expr = c +// mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki +// mod_l_expr = c2 // mod_r_expr = k1 * k2 * ... * ki -// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) +// where c1 ~= c2 mod k1 * k2 * ... * ki +// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr& mult_expr, +inline std::pair MergeMulModInner(arith::Analyzer* analyzer, + const PrimExpr& mult_expr, const PrimExpr& mod_l_expr, const PrimExpr& mod_r_expr) { using namespace tir; @@ -119,9 +121,10 @@ inline std::pair MergeMulModInner(const PrimExpr& mult_expr, } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && - expr_equal(inner_div_ptr->a, mod_l_expr)) { + analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) { // Found! - PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; + PrimExpr ret = + no_opt_sum.get() ? 
no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a; return std::make_pair(true, ret); } else { return std::make_pair(false, PrimExpr()); @@ -204,7 +207,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { std::pair ret = - MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); + MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 77678d829a8e2..ba329676b1c33 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,17 +76,16 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto padded_iter_map = - DetectPaddedIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ false, &analyzer, - /* simplify_trivial_iterators = */ false); - CHECK(padded_iter_map.errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map.errors[0]; + auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/arith::IterMapLevel::NoCheck, &analyzer, + /*simplify_trivial_iterators=*/false); + CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. Map inverse_exprs_map = InverseAffineIterMap( - padded_iter_map.indices, Array(output_vars.begin(), output_vars.end())); + padded_iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. Array inverse_exprs; @@ -94,7 +93,7 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia inverse_exprs.push_back(inverse_exprs_map.at(index)); } - PrimExpr padding_predicate = padded_iter_map.padding_predicate; + PrimExpr padding_predicate = padded_iter_map->padding_predicate; padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); padding_predicate = Substitute(padding_predicate, inverse_exprs_map); @@ -141,14 +140,14 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // indices. arith::Analyzer analyzer; auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ true, &analyzer, + /* check_level = */ arith::IterMapLevel::Bijective, &analyzer, /* simplify_trivial_iterators = */ false); - CHECK(iter_map.size()) << "Index transformation was not bijective."; + CHECK(iter_map->indices.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the // output variables. - Map inverse_exprs_map = - InverseAffineIterMap(iter_map, Array(output_vars.begin(), output_vars.end())); + Map inverse_exprs_map = InverseAffineIterMap( + iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. 
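IndexMap::Inverse above pairs DetectIterMap at the bijective level with InverseAffineIterMap to express the input iterators in terms of the output variables. A minimal Python sketch of that pairing, assuming the updated bindings; the split-by-16 mapping is only an illustrative stand-in, not taken from the patch:

import tvm
import tvm.testing
from tvm.tir import floordiv, floormod

l0 = tvm.tir.Var("l0", "int32")
dom = {l0: tvm.ir.Range.from_min_extent(0, 64)}

iter_map = tvm.arith.detect_iter_map(
    [floordiv(l0, 16), floormod(l0, 16)], dom, check_level="bijective"
).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
inverse = tvm.arith.inverse_affine_iter_map(iter_map, outputs)

# The input iterator is reconstructed from the two outputs.
tvm.testing.assert_prim_expr_equal(inverse[l0], outputs[0] * 16 + outputs[1])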
Array inverse_exprs; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c4719015daa43..83ef6adae3b23 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -533,16 +533,16 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } - Array results = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, - /*require_bijective=*/false, + /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/analyzer); - if (results.empty()) { + if (res->indices.empty()) { return false; } - for (const arith::IterSumExpr& sum_expr : results) { + for (const arith::IterSumExpr& sum_expr : res->indices) { const Array& args = sum_expr->args; if (!args.empty() && !is_one(args[0]->scale)) { return false; diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 993557f8be2f8..b0cafac3151f7 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -68,17 +68,18 @@ class SplitExprCollector { * \param index The indexing pattern * \param input_iters The input iterators' domain * \param predicate The predicate of the affine map - * \param require_bijective Whether the affine map is required to be bijective + * \param check_level The iter mapping checking level * \param analyzer The analyzer * \return The collected split expressions */ static std::vector Collect(const PrimExpr& index, const Map& input_iters, // const PrimExpr& predicate, // - bool require_bijective, // + arith::IterMapLevel check_level, // arith::Analyzer* analyzer) { - Array iter_sum_exprs = arith::DetectIterMap( - {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer); + arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, + predicate, check_level, analyzer); + const auto& iter_sum_exprs = res->indices; if (iter_sum_exprs.empty()) { return {}; } @@ -149,7 +150,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& // Step 3. Detect the IterSplitExpr of the indexing pattern std::vector split_exprs = SplitExprCollector::Collect( /*index=*/f_flatten_index(indices), input_iters, predicate, - /*require_bijective=*/false, analyzer); + /*check_level=*/arith::IterMapLevel::Surjective, analyzer); if (split_exprs.empty()) { return NullOpt; } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 7ed80a1c5b8f2..4ede2dd90da80 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -258,10 +258,9 @@ Array> CheckSubspaceDivisible(const IRModule& mod, arith::Analyzer* analyzer) { const Block& block = block_realize->block; - Array> division = - arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, - collector.inner_loop_vars, block_realize->predicate, - /*require_bijective=*/false, analyzer); + Array> division = arith::SubspaceDivide( + block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, + block_realize->predicate, arith::IterMapLevel::Surjective, analyzer); if (division.empty()) { // If we can't do perfect subspace division, check if it is a trivial case of subspace division. 
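The SubspaceDivide call above is what blockize/tensorize uses to split block bindings into an outer part over the outer loop iterators and an inner part over sub_iters. A small Python sketch with the binding's default check level; the i/j binding is an illustrative stand-in rather than a case from this patch:

import tvm
import tvm.testing

i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
dom = {
    i: tvm.ir.Range.from_min_extent(0, 4),
    j: tvm.ir.Range.from_min_extent(0, 8),
}

# Divide the single binding i*8 + j over the inner subspace {j}.
# Each row of the result is an [outer, inner] pair; the final row holds the
# outer/inner predicates.
res = tvm.arith.subspace_divide([i * 8 + j], dom, [j])
outer, inner = res[0][0], res[0][1]
tvm.testing.assert_prim_expr_equal(outer, i)  # outer part iterates over i
tvm.testing.assert_prim_expr_equal(inner, j)  # inner part iterates over j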
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 2a349f8fe61ed..7f1d74ac20214 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -244,7 +244,7 @@ class ScopeReconstructor : private StmtMutator { if (preserve_unit_loops || !is_one(iter_dom->extent)) { Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); loop_vars.push_back(var); - loop_extents.push_back(iter_dom->extent); + loop_extents.push_back(analyzer->Simplify(iter_dom->extent)); iter_values.push_back(iter_dom->min + var); analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent)); } else { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 452f72e7228f0..ad15e06e285af 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -552,13 +552,14 @@ class ReverseComputeInliner : public BaseInliner { } } - buffer_load_iter_map_ = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/buffer_load_indices_, /*input_iters=*/consumer_iter_doms, /*predicate=*/true, - /*require_bijective=*/true, + /*check_level=*/arith::IterMapLevel::Bijective, /*analyzer=*/&analyzer, /*simplify_trivial_iterators=*/false); + buffer_load_iter_map_ = res->indices; if (buffer_load_iter_map_.empty()) { // Failure: indices of BufferLoad are not bijective affine return false; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 6da796fc955f3..692f68a600ae9 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -392,8 +392,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, auto iter_map = arith::DetectIterMap( /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true), - /*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true); - if (iter_map.empty()) { + /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*simplify_trivial_iterators=*/true); + if (iter_map->indices.empty()) { throw NotBijectiveAffineIndexMapError(self->mod, index_map); } @@ -417,7 +418,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. - auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars); + auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars); // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant // zero. 
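TransformBlockLayout above now requests a bijective mapping and raises NotBijectiveAffineIndexMapError when the returned indices are empty. A minimal Python illustration of the surjective/bijective distinction, assuming the updated bindings and mirroring the cases from test_trivial in this patch's tests:

import tvm

x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
dom = {
    x: tvm.ir.Range.from_min_extent(0, 3),
    y: tvm.ir.Range.from_min_extent(0, 4),
    z: tvm.ir.Range.from_min_extent(0, 1),
}

# [x, y] is bijective over the domain (z has a trivial extent of 1) ...
res = tvm.arith.detect_iter_map([x, y], dom, check_level="bijective")
assert len(res.indices) == 2, res.errors

# ... while [x, z] leaves y uncovered, so the bijective check rejects it and
# records the reason in res.errors.
res = tvm.arith.detect_iter_map([x, z], dom, check_level="bijective")
assert len(res.indices) == 0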
for (const auto& iter_var : block_ptr->iter_vars) { diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index dbe6a3bbc0c5c..5315b139f0f6f 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -115,7 +115,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, - /*require_bijective=*/false); + /*check_level=*/arith::IterMapLevel::Surjective); if (v.same_as(op->iter_values)) { return GetRef(op); } else { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index fe766b921806b..d7bfa1c919478 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from xml import dom import tvm import tvm.testing -from tvm import te from tvm.tir import floormod, floordiv @@ -48,56 +48,69 @@ def convert_iter_expr(expr): return tvm.arith.normalize_iter_map_to_expr(expr) -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): - """Check the sum expr have the right pattern.""" - assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent == 1: - assert len(sum_expr.args) == 0 - else: - assert len(sum_expr.args) == 1 - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) - tvm.testing.assert_prim_expr_equal(sum_expr.base, base) +def assert_iter_sum_pattern( + expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True +): + keys = list(expect_dict.keys()) + res = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + indices = res.indices + assert len(indices) == len(keys), res.errors + print(indices) + for i, input_iter in enumerate(keys): + spec = expect_dict[input_iter] + ( + extent, + base, + ) = spec[0:2] + scale = spec[2] if len(spec) > 2 else 1 + expect_iter = spec[3] if len(spec) > 3 else None + sum_expr = indices[i] + assert isinstance(sum_expr, tvm.arith.IterSumExpr) + if extent == 1: + assert len(sum_expr.args) == 0 + else: + assert len(sum_expr.args) == 1 + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.base, base) + if expect_iter is not None: + if not isinstance(expect_iter, tvm.arith.IterMapExpr): + sum_expr = convert_iter_expr(sum_expr) + tvm.ir.assert_structural_equal(sum_expr, expect_iter) + + +def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): + res = tvm.arith.detect_iter_map( + list(iters), dom_map, predicate=predicate, check_level=check_level + ).indices + assert len(res) == 0 def test_trivial(): - x = tvm.tir.Var("x", "int32"), 3 - y = tvm.tir.Var("y", "int32"), 4 - z = tvm.tir.Var("z", "int32"), 1 - - res = tvm.arith.detect_iter_map([x[0], y[0], 3], var_dom([x, y])) + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) - assert len(res) == 3 - 
assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], 1, 3) - - res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 1, 3) + assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) + assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map) # not independent - res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) - assert len(res) == 0 + assert_iter_sum_failure([x, x, 3], dom_map) - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=True + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=False + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - # not bijective - res = tvm.arith.detect_iter_map([x[0], z[0]], var_dom([x, y, z]), require_bijective=True) - assert len(res) == 0 + assert_iter_sum_failure([x, z], dom_map, check_level="bijective") def test_fuse(): @@ -106,42 +119,27 @@ def test_fuse(): c = tvm.tir.SizeVar("c", "int32") c0 = tvm.tir.SizeVar("c0", "int32") - res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 1 + c) + assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) - res = tvm.arith.detect_iter_map([ifuse([(x, 3), (y, 4)])[0]], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 0) + assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)])) # fuse with symbolic factor - res = tvm.arith.detect_iter_map([(y + 1) * c + x], var_dom([(x, c), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 4 * c, c) + assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)])) # duplication - res = tvm.arith.detect_iter_map([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 - - # duplication 2 - res = tvm.arith.detect_iter_map([y, x + 1, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) + assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)])) # factor mismatch - res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)])) # simple stride pattern - res = tvm.arith.detect_iter_map([x * 4 + y * 2], var_dom([(x, 3), (y, 2)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 6, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * 2 + y) * 2) + assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)])) # simple stride pattern with symbolic - res = tvm.arith.detect_iter_map([x * 2 * c0 + y * 2], var_dom([(x, 3), (y, c0)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 3 * c0, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * c0 + y) * 2) + assert_iter_sum_pattern( + 
{x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)]) + ) def test_split(): @@ -152,171 +150,138 @@ def test_split(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - res = tvm.arith.detect_iter_map([fld(x, 3), flm(x, 3) * 2 + c1], var_dom([(x, 24)])) + assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 8, 0) - assert_iter_sum_pattern(res[1], 3, c1, 2) - - res = tvm.arith.detect_iter_map([fld(x, 6), fld(flm(x, 6), 2), flm(x, 2)], var_dom([(x, 24)])) - - assert len(res) == 3 - assert_iter_sum_pattern(res[0], 4, 0) - assert_iter_sum_pattern(res[1], 3, 0) - assert_iter_sum_pattern(res[2], 2, 0) + assert_iter_sum_pattern( + {fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)]) + ) # simple symbolic bound # TODO(tvm-team) improve symbolic divisible check to enable # more complicated symbolic bound - res = tvm.arith.detect_iter_map([fld(x, c0), flm(x, c0)], var_dom([(x, c1 * c0)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], c1, 0) - assert_iter_sum_pattern(res[1], c0, 0) - - res = tvm.arith.detect_iter_map([fld(x * 2, 4), flm(x * 2, 4)], var_dom([(x, 8)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 4, 0, scale=1) - assert_iter_sum_pattern(res[1], 2, 0, scale=2) + assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)])) - res = tvm.arith.detect_iter_map([fld(x * 2, 4) * 4 + flm(x * 2, 4)], var_dom([(x, 8)])) + assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 8, 0, scale=2) + assert_iter_sum_pattern( + { + fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2), + }, + var_dom([(x, 8)]), + ) - res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) - assert len(res) == 0 + assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) def test_compound(): - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 18, 0) - assert_iter_sum_pattern(res[1], 5, 0) # reconstruct the pattern manually - mx = tvm.arith.IterMark(x[0], 10) - my = tvm.arith.IterMark(y[0], 9) - + mx = tvm.arith.IterMark(x, 10) + my = tvm.arith.IterMark(y, 9) xoscale = 3 - xiscale = 1 yoscale = 6 yiscale = 1 mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) - mxi = tvm.arith.IterSplitExpr(mx, 1, 5, xiscale) myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) - mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) - tvm.ir.assert_structural_equal(sz, res[0]) + assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) def test_predicate(): - x = tvm.tir.Var("x", "int32"), 13 - y = tvm.tir.Var("y", "int32"), 10 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") # available contraints # upper bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - 
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127 + ) # lower bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6 + ) # lower bound + upper bound - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), + + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) # constraint on one fused iter i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") k = tvm.tir.Var("k", "int32") - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (88, 1)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 88, 1) # constraint on single var - res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) - assert_iter_sum_pattern(res[0], 10, 0) + assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10) - # iterations are subparts of constraint, invalid, case 1 - res = tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 1 + assert_iter_sum_failure( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), ) - assert len(res) == 0 - # iterations are subparts of constraint, invalid, case 2 - res = tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 2 + assert_iter_sum_failure( [i * 128 + j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=i * 16384 + j * 128 + k < 100, ) - assert len(res) == 0 # irrelavant predicate - res = tvm.arith.detect_iter_map( - [i + j], - var_dom([(i, 1)]), - j <= 24, - ) - assert_iter_sum_pattern(res[0], 1, j) + assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24) # constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (22, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25), + 
predicate=tvm.tir.all( + 1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25 + ), ) - assert_iter_sum_pattern(res[0], 22, 3) # duplicate constraint on one fused iter - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (66, 2)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 66, 2) # duplicate constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (15, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all( + predicate=tvm.tir.all( 1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, @@ -327,15 +292,13 @@ def test_predicate(): i * 6 + j * 2 + k < 18, ), ) - assert_iter_sum_pattern(res[0], 15, 3) # constraint on non-disjoint fused iters should fail - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i * 8 + j * 2 + k], var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) - assert len(res) == 0 # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) @@ -347,147 +310,135 @@ def test_predicate(): i3 = tvm.tir.Var("i3", "int32") i4 = tvm.tir.Var("i4", "int32") i5 = tvm.tir.Var("i5", "int32") - res = tvm.arith.detect_iter_map( - [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5], + assert_iter_sum_pattern( + {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), - tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), + predicate=tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), ) - assert_iter_sum_pattern(res[0], 540, 93) # constraint on many disjoint fused iters, case 2 - res = tvm.arith.detect_iter_map( - [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4], + assert_iter_sum_pattern( + {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), - tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10), + predicate=tvm.tir.all( + 3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10 + ), ) - assert_iter_sum_pattern(res[0], 135, 28) # constraint on split iters - res = tvm.arith.detect_iter_map( - [i % 16, i // 16], + assert_iter_sum_pattern( + {i % 16: (7, 3), i // 16: (8, 4)}, var_dom([(i, 1024)]), - tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), - require_bijective=True, + predicate=tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + check_level="bijective", ) - assert_iter_sum_pattern(res[0], 7, 3) - assert_iter_sum_pattern(res[1], 8, 4) # constraint on split iters, nested case 1 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (7, 3)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + predicate=tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), ) - assert_iter_sum_pattern(res[0], 7, 3) # constraint on split iters, nested case 2 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_failure( + [ + (i * 32 + j) % 16, + ], var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + 
check_level="bijective", ) - assert len(res) == 0 - res = tvm.arith.detect_iter_map( - [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (16, 0)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + ) + assert_iter_sum_pattern( + {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), ) - assert_iter_sum_pattern(res[0], 16, 0) - assert_iter_sum_pattern(res[1], 4, 0) # non-standard form of predicate - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y + ) # duplicate constraint - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.all(x[0] * 10 + y[0] < 128, x[0] * 10 + y[0] < 64), + assert_iter_sum_pattern( + {x * 10 + y: (64, 0)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 64, 0) - # useless constraint - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 140) - - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 130, 0) + assert_iter_sum_pattern( + {x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140 + ) - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + assert_iter_sum_pattern( + {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, ) ), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 # wrong constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 7, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 7, ) ), ) - assert len(res) == 0 # incompatible constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, + i1 * 4 + i3 < 20, ) ), ) - 
assert len(res) == 0 - - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i1 * 4 + i3 < 20, ) ), ) - assert len(res) == 0 # zero iter - xo = tvm.tir.Var("xo", "int32"), 1 - xi = tvm.tir.Var("xi", "int32"), 129 - y = tvm.tir.Var("y", "int32"), 128 - - res = tvm.arith.detect_iter_map( - [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + xo = tvm.tir.Var("xo", "int32") + xi = tvm.tir.Var("xi", "int32") + y = tvm.tir.Var("y", "int32") + assert_iter_sum_pattern( + {xo * 129 + xi: (128, 0), y: (128, 0)}, + var_dom([(xo, 1), (xi, 129), (y, 128)]), + predicate=xo * 129 + xi < 128, ) @@ -554,9 +505,10 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + assert_iter_sum_pattern + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.2 @@ -568,9 +520,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.3 @@ -589,9 +541,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) tvm.ir.assert_structural_equal(res[2][1], True) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.5 @@ -607,9 +559,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], True) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.6 @@ -644,9 +596,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], 
var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.2 @@ -662,9 +614,11 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])) + res1 = tvm.arith.detect_iter_map( + [res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3]) + ).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices assert len(res2) == 3 # compound 2.3 @@ -692,9 +646,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.5 @@ -730,13 +684,6 @@ def test_complex(): i0 = ifuse([j0, j1], 200) i1 = ifuse([j2, j3], 50) - res = tvm.arith.detect_iter_map( - [i0[0], i1[0]], - var_dom([l0, l1, n0, n1, m1, l3]), - tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), - ) - assert len(res) == 2 - n0_mark = tvm.arith.IterMark(n0[0], n0[1]) n1_mark = tvm.arith.IterMark(n1[0], n1[1]) l0_mark = tvm.arith.IterMark(l0[0], l0[1]) @@ -784,16 +731,20 @@ def test_complex(): i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) - tvm.ir.assert_structural_equal(i0_final, res[0]) - tvm.ir.assert_structural_equal(i1_final, res[1]) + assert_iter_sum_pattern( + {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, + var_dom([l0, l1, n0, n1, m1, l3]), + predicate=tvm.tir.all( + i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15 + ), + ) # wrong constraint - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), ) - assert len(res) == 0 # subspace_division res = tvm.arith.subspace_divide( @@ -822,34 +773,33 @@ def test_complex(): ), ) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1]) - assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1])) - assert len(res2) == 2 + assert_iter_sum_pattern( + {res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1] + ) + assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1])) def test_normalize_iter_map_to_expr(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + 
yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)])) tvm.ir.assert_structural_equal( - tvm.arith.normalize_iter_map_to_expr(res[0]), - fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3), + tvm.arith.normalize_iter_map_to_expr(res.indices[0]), + fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3), ) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5)) # iter mark wrap a complex expr - split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1) def test_inverse_affine_iter_map(): @@ -863,7 +813,9 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 4) l0_1_l1_1_fused = ifuse([l0_1, l1_1]) - iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) + iter_map = tvm.arith.detect_iter_map( + [l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 @@ -882,7 +834,7 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map( [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) - ) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 @@ -902,7 +854,7 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 8) l2 = ifuse([l1_1, l1_0]) - iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 @@ -918,12 +870,11 @@ def test_free_variables(): z = tvm.tir.Var("z", "int32") # illegal iter if z is within dom - res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) - assert len(res) == 0 + assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) # iter is valid if z is free, even there are linear forms of z - res = tvm.arith.detect_iter_map( - [z * 19 + y * 3 + x], + assert_iter_sum_pattern( + {z * 19 + y * 3 + x: (9, z * 19)}, var_dom( [ (x, 3), @@ -931,9 +882,8 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * 19) - res = tvm.arith.detect_iter_map( - [z * z + y * 3 + x], + assert_iter_sum_pattern( + {z * z + y * 3 + x: (9, z * z)}, var_dom( [ (x, 3), @@ -941,7 +891,105 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * z) + + +def test_padding(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + # left padding only, offset divisible + sum = 64 + y + dom_map = var_dom([(y, 192)]) + assert_iter_sum_pattern( + {fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)}, + dom_map, + check_level="bijective", + ) + + # left padding only, offset 
non-divisible + sum = 80 + y + dom_map = var_dom([(y, 176)]) + assert_iter_sum_pattern( + {fld(sum, 32): (6, 2, 1)}, + dom_map, + ) + assert_iter_sum_pattern( + {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)}, + dom_map, + ) + assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map) + assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map) + + # right padding only, offset divisible + sum = x * 32 + y * 8 + dom_map = var_dom([(x, 5), (y, 4)]) + assert_iter_sum_pattern( + {fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)}, + dom_map, + ) + assert_iter_sum_failure({fld(sum, 5)}, dom_map) + + # right padding only, offset non-divisible + dom_map = var_dom([(x, 26)]) + assert_iter_sum_pattern( + {fld(x, 15): (2, 0, 1)}, + dom_map, + ) + assert_iter_sum_pattern( + {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}, + dom_map, + ) + + # padding constants on both side + sum = x + 71 + dom_map = var_dom([(x, 45)]) + assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map) + assert_iter_sum_pattern( + {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}, + dom_map, + ) + + # padding for free iteration part + sum = x * 360 + y + dom_map = var_dom([(y, 360)]) + assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) + assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) + + # multiple split with same mark offset, could + # be surjective on missing (padded // LCM) + assert_iter_sum_pattern( + { + flm(x + 10, 3): (3, 0), + flm(fld(x + 10, 3), 4): (4, 0), + flm(fld(fld(x + 10, 3), 4), 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + assert_iter_sum_failure( + { + flm(x + 10, 3), + flm(fld(x + 10, 3), 4), + flm(fld(fld(x + 10, 3), 4), 5), + fld(fld(fld(x + 10, 3), 4), 5), + }, + var_dom([(x, 240)]), + ) + + # different offsets on splits + assert_iter_sum_pattern( + { + flm(x + 1, 3): (3, 0), + flm(fld(x + 10, 3) + 2, 4): (4, 0), + flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + + # original extent is smaller than the divident + # it is not surjective wrt to the region [0, 16) + assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)])) if __name__ == "__main__": diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 8d26710f40dbf..82e1372f991e1 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -459,11 +459,13 @@ def test_div_index_simplify(): def test_floordiv_index_simplify(): # short name for floordiv fld = tvm.te.floordiv + flm = tvm.te.floormod ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck.verify(fld(fld(x, 2), 3), fld(x, 6)) ck.verify(fld(fld(x, 2) + 1, 3), fld(x + 2, 6)) + ck.verify(fld(x - flm(x, 21), 21), fld(x, 21)) ck.verify(fld(x * 2, 4), fld(x, 2)) ck.verify(fld(x * 4, 2), x * 2) @@ -472,11 +474,17 @@ def test_floordiv_index_simplify(): ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16)) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) + ck.verify(fld(x * 360 + y, 16), x * 22) + ck.verify(fld(x * 360 + y, 25), x * 14) + ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)) + ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) ck.verify(fld(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, fld(y, 2))) - ck.verify(fld(y + x * 
4, 2), fld(y, 2) + x * 2) + ck.verify(fld(y + x * 4, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)) ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)) @@ -549,15 +557,17 @@ def test_mod_index_simplify(): def test_floormod_index_simplify(): # short name for floordiv flm = tvm.te.floormod - ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck = RewriteChecker() x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z") ck.verify(flm(x * 10, 2), 0) + ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400)) ck.verify(flm(x * 10 + y, 2), flm(y, 2)) + ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16)) ck.verify(flm(x + 10, 2), flm(x, 2)) ck.verify(flm(x + y * 10, 2), flm(x, 2)) + ck.verify(flm(x + y * 360, 16), flm(x + y * 8, 16)) ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1) ck.verify(flm(x * (-10), 2), 0) ck.verify(flm(x * (-10) + y, 2), flm(y, 2)) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 337f9cbc07223..10e827978cc0a 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -137,6 +137,7 @@ def assert_simplified_equal(index_simplified, index_direct): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod + # Test Case1 index_simplified = A_stride.offset_of( (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1) @@ -174,7 +175,7 @@ def assert_simplified_equal(index_simplified, index_direct): j = te.size_var("j") k = te.size_var("k") - index_simplified = B.offset_of( + index_simplified1 = B.offset_of( ( idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), @@ -182,8 +183,17 @@ def assert_simplified_equal(index_simplified, index_direct): idxm((i * 50176 + j * 28672 + k), 1024), ) ) + index_simplified2 = B.offset_of( + ( + idxd(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(i * 7 + idxd(k, 1024), 14), + idxm(k, 1024), + ) + ) index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k))) - assert_simplified_equal(index_simplified, index_direct) + assert_simplified_equal(index_simplified1, index_direct) + assert_simplified_equal(index_simplified2, index_direct) @tvm.testing.requires_llvm diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index b06dcebe1d1c5..f477367adfad3 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1249,6 +1249,44 @@ def test_compute_at_simplify_static_bound(): verify_trace_roundtrip(sch=sch, mod=static_bound) +def test_compute_at_non_perfect_channel_group(): + @T.prim_func + def grouped_channel_bias( + X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32", scope="") + for i in T.grid(45): + with T.block("init"): + vi = T.axis.remap("S", [i]) + B[vi] = vi + for c_o, h, w, c_i in T.grid(2, 8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + @T.prim_func + def grouped_channel_bias_non_perfect_tiled( + X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32") + for c_o in range(2): + for ax0 in range(23): + with T.block("init"): + vi = 
T.axis.spatial(45, c_o * 22 + ax0) + B[vi] = vi + for h, w, c_i in T.grid(8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + sch = tir.Schedule(grouped_channel_bias, debug_mask="all") + loop = sch.get_loops(sch.get_block("compute"))[0] + sch.compute_at(sch.get_block("init"), loop) + tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) + + def test_fail_subtree_complete_block(): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0") From ac5d7813dff34566645787c9f3f2e6576dd723da Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 31 May 2022 20:03:04 +0100 Subject: [PATCH 003/181] [microNPU] Fix flaky compute cycle annotation test (#11510) Fixes non-deterministic test by disabling striping when running the cascader. Change-Id: Ib44f299f21fa0b41be4bfac3deb61a9c16818c58 --- tests/python/contrib/test_ethosu/cascader/test_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/cascader/test_scheduler.py b/tests/python/contrib/test_ethosu/cascader/test_scheduler.py index b3610315441ef..2dce6dfdd67ed 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/cascader/test_scheduler.py @@ -48,7 +48,6 @@ def test_cascade(SRAM, FLASH, TwoConv2DWithSliceTE, TwoConv2DTE, MobileNetv1Star cs.cascade(sch, te_graph, const_dict, options, SRAM, FLASH, [SRAM], device_config) -@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11483") def test_compute_cycles_annotation(SRAM, FLASH, TwoConv2DTE): device_config = cs.EthosuDeviceConfig("ethos-u55-256") options = infra.make_options( @@ -61,6 +60,7 @@ def test_compute_cycles_annotation(SRAM, FLASH, TwoConv2DTE): always_copy_size=1024, disable_pareto_plans=False, disable_pareto_proposals=False, + enable_striping=False, ) sch, te_graph, const_dict = TwoConv2DTE cs.cascade(sch, te_graph, const_dict, options, SRAM, FLASH, [SRAM], device_config) @@ -69,7 +69,7 @@ def test_compute_cycles_annotation(SRAM, FLASH, TwoConv2DTE): # [copy, copy, conv2d, copy, conv2d] stages = [6, 8, 9, 18, 19] # Expected hints for each operation - compute_cycles_hints = [4096, 5120, 1632, 2560, 3072] + compute_cycles_hints = [4096, 5120, 1440, 2560, 3072] for stage, compute_cycles_hint in zip(stages, compute_cycles_hints): op = sch.stages[stage] From 2252f958f75c6e33b946d23f1ebb803d41f0b63d Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 31 May 2022 13:27:01 -0700 Subject: [PATCH 004/181] [microTVM][ARM][Zephyr] Add CMSIS dependencies in Zephyr project build (#11362) * Test with CMSIS build added disabled conv2d_nhwc_dsp.arm_cpu for non integers workloads added debugging feature to TempDirectory * revert arm_cpu strategy changes * Address Andrew comments * change copy to include * add cmsis_path only as project option --- .../template_project/microtvm_api_server.py | 45 ++++++++++-- python/tvm/contrib/utils.py | 14 ++-- tests/micro/zephyr/conftest.py | 21 +++++- tests/micro/zephyr/test_zephyr.py | 70 +++++++++++++++++++ tests/micro/zephyr/test_zephyr_aot.py | 2 + tests/micro/zephyr/test_zephyr_armv7m.py | 1 + tests/scripts/task_python_microtvm.sh | 1 + 7 files changed, 144 insertions(+), 10 deletions(-) diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 
059e7604896c0..bcf9f78f4b112 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -27,7 +27,6 @@ import pathlib import queue import re -import select import shlex import shutil import subprocess @@ -35,7 +34,7 @@ import tarfile import tempfile import threading -import time +from typing import Union import usb import serial @@ -323,6 +322,12 @@ def _get_nrf_device_args(options): type="str", help="Extra definitions added project compile.", ), + server.ProjectOption( + "cmsis_path", + optional=["generate_project"], + type="str", + help="Path to the CMSIS directory.", + ), ] @@ -333,6 +338,13 @@ def get_zephyr_base(options: dict): return zephyr_base +def get_cmsis_path(options: dict) -> pathlib.Path: + """Returns CMSIS dependency path""" + cmsis_path = options.get("cmsis_path") + assert cmsis_path, "'cmsis_path' option not passed!" + return pathlib.Path(cmsis_path) + + class Handler(server.ProjectAPIHandler): def __init__(self): super(Handler, self).__init__() @@ -424,6 +436,17 @@ def _get_platform_version(self, zephyr_base: str) -> float: return float(f"{version_major}.{version_minor}") + def _cmsis_required(self, project_path: Union[str, pathlib.Path]) -> bool: + """Check if CMSIS dependency is required.""" + project_path = pathlib.Path(project_path) + for path in (project_path / "codegen" / "host" / "src").iterdir(): + if path.is_file(): + with open(path, "r") as lib_f: + lib_content = lib_f.read() + if "" in lib_content and "" in lib_content: + return True + return False + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Check Zephyr version version = self._get_platform_version(get_zephyr_base(options)) @@ -470,8 +493,8 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec shutil.copy2(src_path, dst_path) # Populate Makefile. - with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f: - with open(project_dir / "CMakeLists.txt", "w") as cmake_f: + with open(project_dir / "CMakeLists.txt", "w") as cmake_f: + with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f: for line in cmake_template_f: if self.API_SERVER_CRT_LIBS_TOKEN in line: crt_libs = self.CRT_LIBS_BY_PROJECT_TYPE[options["project_type"]] @@ -484,6 +507,20 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec for item in flags: cmake_f.write(f"target_compile_definitions(app PUBLIC {item})\n") + # Include CMSIS libraries if required. 
+ if self._cmsis_required(extract_path): + cmsis_path = get_cmsis_path(options) + cmake_f.write("\n") + cmake_f.write( + f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "DSP" / "Include")})\n' + ) + cmake_f.write( + f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "DSP" / "Include" / "dsp")})\n' + ) + cmake_f.write( + f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "NN" / "Include")})\n' + ) + self._create_prj_conf(project_dir, options) # Populate crt-config.h diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index e2ca182779c6f..89688b5bf86f4 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -93,11 +93,15 @@ def set_keep_for_debug(cls, set_to=True): finally: cls._KEEP_FOR_DEBUG = old_keep_for_debug - def __init__(self, custom_path=None): + def __init__(self, custom_path=None, keep_for_debug=None): if self.TEMPDIRS is None: raise DirectoryCreatedPastAtExit() - self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG + if keep_for_debug is not None: + self._created_with_keep_for_debug = keep_for_debug + else: + self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG + if custom_path: os.mkdir(custom_path) self.temp_dir = custom_path @@ -169,7 +173,7 @@ def listdir(self): atexit.register(TempDirectory.remove_tempdirs) -def tempdir(custom_path=None): +def tempdir(custom_path=None, keep_for_debug=None): """Create temp dir which deletes the contents when exit. Parameters @@ -177,12 +181,14 @@ def tempdir(custom_path=None): custom_path : str, optional Manually specify the exact temp dir path + keep_for_debug : bool + Keep temp directory for debugging purposes Returns ------- temp : TempDirectory The temp directory object """ - return TempDirectory(custom_path) + return TempDirectory(custom_path=custom_path, keep_for_debug=keep_for_debug) class FileLock(object): diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 177ca8aa269e8..997237d370a5d 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -59,7 +59,7 @@ def tvm_debug(request): @pytest.fixture -def temp_dir(board): +def temp_dir(board, tvm_debug): parent_dir = pathlib.Path(os.path.dirname(__file__)) filename = os.path.splitext(os.path.basename(__file__))[0] board_workspace = ( @@ -76,4 +76,21 @@ def temp_dir(board): if not os.path.exists(board_workspace.parent): os.makedirs(board_workspace.parent) - return tempdir(board_workspace) + keep_for_debug = tvm_debug if tvm_debug else None + test_temp_dir = tempdir(custom_path=board_workspace, keep_for_debug=keep_for_debug) + return test_temp_dir + + +@pytest.fixture(autouse=True) +def skip_by_board(request, board): + """Skip test if board is in the list.""" + if request.node.get_closest_marker("skip_boards"): + if board in request.node.get_closest_marker("skip_boards").args[0]: + pytest.skip("skipped on this board: {}".format(board)) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "skip_by_board(board): skip test for the given board", + ) diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index f89d11cf44dcc..2651435434b11 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -22,6 +22,7 @@ import pytest import numpy as np + import onnx from PIL import Image @@ -32,6 +33,7 @@ from tvm.relay.testing import byoc from tvm.contrib import utils from tvm.micro.testing.utils import check_tune_log +from tvm.target import 
arm_isa import test_utils @@ -87,6 +89,7 @@ def _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config, dtype= # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_add_uint(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" @@ -112,6 +115,7 @@ def test_basic_add(sess): # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_add_float(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" model = test_utils.ZEPHYR_BOARDS[board] @@ -138,6 +142,7 @@ def test_basic_add(sess): @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_platform_timer(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" @@ -167,6 +172,7 @@ def test_basic_add(sess): @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_relay(temp_dir, board, west_cmd, tvm_debug): """Testing a simple relay graph""" model = test_utils.ZEPHYR_BOARDS[board] @@ -199,6 +205,7 @@ def test_relay(temp_dir, board, west_cmd, tvm_debug): @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_onnx(temp_dir, board, west_cmd, tvm_debug): """Testing a simple ONNX model.""" model = test_utils.ZEPHYR_BOARDS[board] @@ -279,6 +286,7 @@ def check_result( @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_byoc_microtvm(temp_dir, board, west_cmd, tvm_debug): """This is a simple test case to check BYOC capabilities of microTVM""" model = test_utils.ZEPHYR_BOARDS[board] @@ -359,6 +367,7 @@ def _make_add_sess_with_shape(temp_dir, model, zephyr_board, west_cmd, shape, bu ], ) @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_rpc_large_array(temp_dir, board, west_cmd, tvm_debug, shape): """Test large RPC array transfer.""" model = test_utils.ZEPHYR_BOARDS[board] @@ -504,5 +513,66 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): tvm.testing.assert_allclose(output, expected_output, rtol=1e-4, atol=1e-5) +@tvm.testing.requires_micro +def test_schedule_build_with_cmsis_dependency(temp_dir, board, west_cmd, tvm_debug): + """Test Relay schedule with CMSIS dependency. This test shows if microTVM Auto tuning + with Zephyr breaks if CMSIS dependency was required for a schedule. + """ + model = test_utils.ZEPHYR_BOARDS[board] + build_config = {"debug": tvm_debug} + target = tvm.target.target.micro(model, options=["-keys=arm_cpu,cpu"]) + + isa = arm_isa.IsaAnalyzer(target) + if not isa.has_dsp_support: + pytest.skip(f"ISA does not support DSP. 
target: {target}") + + # Create a Relay conv2d + data_shape = (1, 16, 16, 3) + weight_shape = (5, 5, 8, 3) + data = relay.var("data", relay.TensorType(data_shape, "int8")) + weight = relay.var("weight", relay.TensorType(weight_shape, "int8")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + data_layout="NHWC", + kernel_layout="HWOI", + out_dtype="int32", + ) + func = relay.Function([data, weight], y) + ir_mod = tvm.IRModule.from_expr(func) + + runtime = Runtime("crt", {"system-lib": True}) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(ir_mod, target=target, runtime=runtime) + + project_options = { + "project_type": "host_driven", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": board, + "cmsis_path": os.getenv("CMSIS_PATH"), + } + + project_dir = temp_dir / "project" + project = tvm.micro.generate_project( + str(test_utils.TEMPLATE_PROJECT_DIR), + mod, + project_dir, + project_options, + ) + project.build() + + with open(project_dir / "CMakeLists.txt", "r") as cmake_f: + cmake_content = cmake_f.read() + + assert "CMSIS/DSP/Include" in cmake_content + assert "CMSIS/DSP/Include/dsp" in cmake_content + assert "CMSIS/DSP/Include" in cmake_content + assert "CMSIS/NN/Include" in cmake_content + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index cfe2ce2ae3c8f..3d509f100d6ec 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -38,6 +38,7 @@ @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_tflite(temp_dir, board, west_cmd, tvm_debug): """Testing a TFLite model.""" model = test_utils.ZEPHYR_BOARDS[board] @@ -93,6 +94,7 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_qemu_make_fail(temp_dir, board, west_cmd, tvm_debug): """Testing QEMU make fail.""" if board not in ["qemu_x86", "mps2_an521", "mps3_an547"]: diff --git a/tests/micro/zephyr/test_zephyr_armv7m.py b/tests/micro/zephyr/test_zephyr_armv7m.py index 2631e43799668..c629403ced821 100644 --- a/tests/micro/zephyr/test_zephyr_armv7m.py +++ b/tests/micro/zephyr/test_zephyr_armv7m.py @@ -103,6 +103,7 @@ def _apply_desired_layout_no_simd(relay_mod): @tvm.testing.requires_micro +@pytest.mark.skip_boards(["mps2_an521"]) def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug): """Testing a ARM v7m SIMD extension.""" diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 557e938a6ed3a..2274c6ca6b283 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -27,6 +27,7 @@ make cython3 run_pytest ctypes python-microtvm-zephyr-qemu_x86 tests/micro/zephyr --zephyr-board=qemu_x86 run_pytest ctypes python-microtvm-zephyr-qemu_riscv32 tests/micro/zephyr --zephyr-board=qemu_riscv32 run_pytest ctypes python-microtvm-zephyr-qemu_riscv64 tests/micro/zephyr --zephyr-board=qemu_riscv64 +run_pytest ctypes python-microtvm-zephyr-mps2_an521 tests/micro/zephyr --zephyr-board=mps2_an521 # Arduino run_pytest ctypes python-microtvm-arduino apps/microtvm/arduino/template_project/tests From a71536a130685a50582eea8c993030872cddb145 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 31 May 2022 15:57:30 -0700 Subject: [PATCH 005/181] [MetaSchedule] Enable Task Filtering (#11512) This PR allows 
`relay.backend.MetaScheduleExtractTask` to take an extra argument `filter_func` which filters out tasks that don't need tuning. The counterpart of AutoScheduler is `traverse_to_get_io_tensors`. --- python/tvm/meta_schedule/relay_integration.py | 8 +- python/tvm/te/__init__.py | 2 +- python/tvm/te/operation.py | 29 ++----- src/relay/backend/task_extraction.cc | 80 +++++++++++++------ src/te/operation/create_primfunc.cc | 33 -------- src/te/operation/create_primfunc.h | 3 - src/tir/schedule/concrete_schedule.cc | 20 ++--- .../test_meta_schedule_integration.py | 63 +++++++++++++++ 8 files changed, 140 insertions(+), 98 deletions(-) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 47f76830ab88f..b556338174130 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """MetaSchedule-Relay integration""" -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import numpy as np # type: ignore from tvm import nd @@ -23,6 +23,7 @@ from tvm.ir import IRModule, transform from tvm.runtime import NDArray from tvm.target import Target +from tvm.te import Tensor from .extracted_task import ExtractedTask from .utils import autotvm_silencer @@ -36,6 +37,7 @@ def extract_task_from_relay( opt_level: int = 3, pass_config: Optional[Dict[str, Any]] = None, disabled_pass: Optional[List[str]] = None, + filter_func: Callable[[List[Tensor]], bool] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -53,6 +55,8 @@ def extract_task_from_relay( The pass config of the compiler disabled_pass : Optional[List[str]] The list of disabled passes of the compiler + filter_func : Callable[[List[tvm.te.Tensor]], bool] + The filter function to filter out the extracted tasks Returns ------- @@ -90,4 +94,4 @@ def extract_task_from_relay( config=pass_config, disabled_pass=disabled_pass, ): - return list(extract_task_func(mod, target, relay_params)) + return list(extract_task_func(mod, target, relay_params, filter_func)) diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 4c4e223f2d723..1777d8707c7ce 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -39,7 +39,7 @@ from .tag import tag_scope from .operation import placeholder, compute, scan, extern, var, size_var, const from .operation import thread_axis, reduce_axis -from .operation import create_prim_func, create_prim_func_from_outputs +from .operation import create_prim_func from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 90d7cb5d75dbc..df5dd2c4ffd81 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -15,17 +15,18 @@ # specific language governing permissions and limitations # under the License. """ Operation class for computation declaration.""" +import inspect + # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List, Union -import inspect +from typing import List import tvm._ffi +import tvm.tir +import tvm.tir._ffi_api from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert -import tvm.tir -import tvm.tir._ffi_api from . import _ffi_api from . 
import tag as _tag @@ -528,23 +529,3 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops) - - -def create_prim_func_from_outputs( - outputs: Union[_tensor.Tensor, List[_tensor.Tensor]], -) -> tvm.tir.PrimFunc: - """Create a TensorIR PrimFunc from output tensor(s) in TE - - Parameters - ---------- - outputs : Union[Tensor, List[Tensor]] - The source expression. - - Returns - ------- - func : tir.PrimFunc - The created function. - """ - if not isinstance(outputs, (list, tuple, Array)): - outputs = [outputs] - return _ffi_api.CreatePrimFuncFromOutputs(outputs) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 0895fd42a3077..6ec881111d770 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -31,25 +31,58 @@ namespace tvm { namespace relay { namespace backend { -namespace metaschedule { - -using meta_schedule::ExtractedTask; +bool DefaultTaskFilter(const Array& args) { + using namespace ::tvm::te; + std::vector stack; + std::unordered_set visited; + for (const Tensor& v : args) { + for (const PrimExpr& e : v->shape) { + // Dynamic shape is not supported for now + if (!e->IsInstance()) { + return false; + } + } + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + if (tensor->op->IsInstance()) { + // do nothing + } else if (tensor->op->IsInstance()) { + Array inputs = tensor->op->InputTensors(); + for (const Tensor& v : inputs) { + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + } else { + return false; + } + } + return true; +} -Array ExtractTask(IRModule mod, Target target, - Map params) { +Array ExtractTask( + IRModule mod, Target target, Map params, + runtime::TypedPackedFunc&)> filter_func) { + using meta_schedule::ExtractedTask; + if (filter_func == nullptr) { + filter_func = DefaultTaskFilter; + } backend::BindParamsInModule(mod, params); - // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); - - transform::Sequential seq(pass_seqs); - auto opt_mod = seq(std::move(mod)); + mod = transform::Sequential(pass_seqs)(std::move(mod)); std::vector tasks; std::unordered_map cache; - - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) { + PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { @@ -61,17 +94,19 @@ Array ExtractTask(IRModule mod, Target target, it->second->weight += 1; return; } - Array inputs_outputs; + Array inputs_outputs{nullptr}; std::string fused_name; std::tie(inputs_outputs, fused_name) = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - auto prim_func = tir::CreatePrimFunc(inputs_outputs); - GlobalVar prim_fn_var(fused_name); - IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, prim_func}}); - ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1); - tasks.push_back(extracted_task); - cache.emplace(cache_key, extracted_task); + if (filter_func(inputs_outputs)) { + tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs); + GlobalVar 
prim_fn_var(fused_name); + IRModule relay_mod({{prim_fn_var, relay_func}}); + IRModule tir_mod({{prim_fn_var, prim_func}}); + ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1); + tasks.push_back(extracted_task); + cache.emplace(cache_key, extracted_task); + } } }); // Tasks are extracted via post order visit, return the reversed list. @@ -83,12 +118,7 @@ Array ExtractTask(IRModule mod, Target target, return tasks; } -} // namespace metaschedule - -TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask") - .set_body_typed([](IRModule mod, Target target, Map params) { - return metaschedule::ExtractTask(mod, target, params); - }); +TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask").set_body_typed(ExtractTask); } // namespace backend } // namespace relay diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 7e7dae855802f..03ad551c68391 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -458,40 +458,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { return LayoutFreePlaceholdersNormalizer().Process(std::move(func)); } -PrimFunc CreatePrimFuncFromOutputs(const Array& outputs) { - std::vector stack; - std::unordered_set visited; - for (const te::Tensor& output : outputs) { - if (!visited.count(output.get())) { - visited.insert(output.get()); - stack.push_back(output); - } - } - - Array arg_list; - while (!stack.empty()) { - te::Tensor tensor = stack.back(); - stack.pop_back(); - if (tensor->op->IsInstance()) { - arg_list.push_back(tensor); - } else if (tensor->op->IsInstance()) { - Array inputs = tensor->op->InputTensors(); - for (const te::Tensor& input : inputs) { - if (!visited.count(input.get())) { - visited.insert(input.get()); - stack.push_back(input); - } - } - } - } - for (const te::Tensor& output : outputs) { - arg_list.push_back(output); - } - return CreatePrimFunc(arg_list); -} - TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); -TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs); } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index d911e5ebcdb7d..c3cddd83f57a8 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,9 +30,6 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list); -/*! \brief Create a schedulable TensorIR func from TE compute outputs. 
*/ -PrimFunc CreatePrimFuncFromOutputs(const Array& outputs); - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8066d85a8e7db..2289899c329bb 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -199,16 +199,16 @@ Schedule ConcreteScheduleNode::Copy() { * \param level An ScheduleErrorRenderLevel enum, level of error rendering * \sa ScheduleErrorRenderLevel */ -#define TVM_TIR_SCHEDULE_END(primitive, level) \ - } \ - catch (const ScheduleError& error) { \ - if ((level) == ScheduleErrorRenderLevel::kDetail) { \ - throw tvm::runtime::Error(error.RenderReport(primitive)); \ - } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ - throw tvm::runtime::Error(error.FastErrorString()); \ - } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ - throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ - } \ +#define TVM_TIR_SCHEDULE_END(primitive, level) \ + } \ + catch (const ScheduleError& error) { \ + if ((level) == ScheduleErrorRenderLevel::kDetail) { \ + throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \ + } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ + throw tvm::runtime::Error(error.FastErrorString()); \ + } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ + throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ + } \ } /******** Schedule: Schedule: Sampling ********/ diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index cd6e1b4c405ac..a423bdb48afdf 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -196,6 +196,69 @@ def test_meta_schedule_integration_extract_from_bert_base(): assert expected_shape == shape, t.task_name +@requires_torch +def test_meta_schedule_integration_extract_from_resnet_with_filter_func(): + def filter_func(args) -> bool: + from tvm import te, tir + + has_complex_op = False + visited = set() + + def traverse(t): + nonlocal has_complex_op + assert t.handle is not None + if t.handle.value in visited: + return + if isinstance(t.op, te.PlaceholderOp): + pass + elif isinstance(t.op, te.ComputeOp): + has_complex_op = has_complex_op or any( + [isinstance(e, tir.Reduce) for e in t.op.body] + ) + for x in t.op.input_tensors: + traverse(x) + visited.add(t.handle.value) + + for t in args: + traverse(t) + return has_complex_op + + mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) + extracted_tasks = ms.extract_task_from_relay( + mod, + target="llvm", + params=params, + filter_func=filter_func, + ) + expected_task_names = [ + "fused_" + s + for s in [ + "nn_max_pool2d", + "nn_adaptive_avg_pool2d", + "nn_dense_add", + "nn_conv2d_add", + "nn_conv2d_add_1", + "nn_conv2d_add_2", + "nn_conv2d_add_add_nn_relu", + "nn_conv2d_add_add_nn_relu_1", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_nn_relu_1", + "nn_conv2d_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_nn_relu_5", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", + "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + ] + ] + + assert len(extracted_tasks) == len(expected_task_names) + for t in extracted_tasks: + 
assert t.task_name in expected_task_names, t.task_name + + @requires_torch def test_meta_schedule_integration_apply_history_best(): mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) From 0cd4dd2f2d6cab265844de0cb8745e0de8d22571 Mon Sep 17 00:00:00 2001 From: wangxiang2713 <49302617+wangxiang2713@users.noreply.github.com> Date: Wed, 1 Jun 2022 20:58:14 +0800 Subject: [PATCH 006/181] [BugFix] Add lock for ModuleNode::GetFuncFromEnv (#11467) * [BugFix] Add lock for ModuleNode::GetFuncFromEnv * [BugFix] Add lock for ModuleNode::GetFuncFromEnv --- include/tvm/runtime/module.h | 2 ++ src/runtime/module.cc | 1 + 2 files changed, 3 insertions(+) diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 875d999c64fab..31d05571eefd2 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -33,6 +33,7 @@ #include #include +#include #include #include #include @@ -234,6 +235,7 @@ class TVM_DLL ModuleNode : public Object { private: /*! \brief Cache used by GetImport */ std::unordered_map > import_cache_; + std::mutex mutex_; }; /*! diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 57fe57568994b..633dc7c176711 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -107,6 +107,7 @@ std::string ModuleNode::GetSource(const std::string& format) { } const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { + std::lock_guard lock(mutex_); auto it = import_cache_.find(name); if (it != import_cache_.end()) return it->second.get(); PackedFunc pf; From ee26ecf1d516af3c7693f6cb53901b4a055ef9d4 Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Wed, 1 Jun 2022 15:51:56 +0100 Subject: [PATCH 007/181] [microNPU] Add transform matrices and part matcher to identity op (#11453) * [microNPU] Add transform matrices and part matcher to identity op * Address comments * Enable cascader in identity tests * Address comments --- .../contrib/ethosu/cascader/device_config.py | 46 ++++++---- .../backend/contrib/ethosu/te/identity.py | 87 +++++++++++++++++- .../cascader/test_ethosu_identity_matcher.py | 58 ++++++++++++ .../contrib/test_ethosu/test_codegen.py | 89 +++++++++++-------- 4 files changed, 223 insertions(+), 57 deletions(-) create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_identity_matcher.py diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py index 27aa8b8c78c59..f654a2598ba41 100644 --- a/python/tvm/contrib/ethosu/cascader/device_config.py +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -48,9 +48,24 @@ def __init__(self, shape: List[int], layout="NHWC"): self.width = int(shape[3]) self.depth = int(shape[2]) * int(shape[4]) else: - self.height = int(shape[1]) - self.width = int(shape[2]) - self.depth = int(shape[3]) + # identity layout is NHWC but the shape is not always 4 + length = len(shape) + if length == 4: + self.height = int(shape[1]) + self.width = int(shape[2]) + self.depth = int(shape[3]) + elif length == 3: + self.height = int(shape[0]) + self.width = int(shape[1]) + self.depth = int(shape[2]) + elif length == 2: + self.height = int(shape[0]) + self.width = int(shape[1]) + self.depth = 1 + elif length == 1: + self.height = int(shape[0]) + self.width = 1 + self.depth = 1 def round_up(self, other: "_Shape"): self.height = _round_up(self.height, other.height) @@ -627,18 +642,19 @@ def _get_subkernel_propagator( stride_w = int(op_attrs.get("stride_w", 1)) transform = ifm_propagator.transform - if 
input_layout == "NHCWB16": - transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h) - transform[3][-1] = min(transform[3][-1], self._subkernel_limits[1] - stride_w) - else: - transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h) - transform[2][-1] = min(transform[2][-1], self._subkernel_limits[1] - stride_w) - - if op_type in ("ethosu_pooling", "ethosu_depthwise_conv2d"): - if output_layout == "NHCWB16" and input_layout == "NHWC": - transform[3][-1] = depth - elif output_layout == "NHCWB16" and input_layout == "NHCWB16": - transform[2][-1] = 1 + ((depth - 1) // 16) + if op_type != "ethosu_identity": + if input_layout == "NHCWB16": + transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h) + transform[3][-1] = min(transform[3][-1], self._subkernel_limits[1] - stride_w) + else: + transform[1][-1] = min(transform[1][-1], self._subkernel_limits[0] - stride_h) + transform[2][-1] = min(transform[2][-1], self._subkernel_limits[1] - stride_w) + + if op_type in ("ethosu_pooling", "ethosu_depthwise_conv2d"): + if output_layout == "NHCWB16" and input_layout == "NHWC": + transform[3][-1] = depth + elif output_layout == "NHCWB16" and input_layout == "NHCWB16": + transform[2][-1] = 1 + ((depth - 1) // 16) return Propagator(transform, ifm_propagator.offset) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py index 271ca1542fc5c..0b61e0c28b880 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py @@ -16,7 +16,10 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Tensor Expression for identity""" +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import read_compute, write_compute @@ -56,7 +59,6 @@ def identity_compute( ------- te.Tensor The Output Feature Map tensor. - """ dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale) id_attrs = {"op": "ethosu_identity", "activation": activation} @@ -76,7 +78,86 @@ def identity_compute( name="ethosu_identity", attrs=id_attrs, ) + length = len(ifm.shape) + ifm_matrix = np.identity(length + 1) + offset = np.zeros(length, dtype="int64") + ifm_propagator = Propagator( + ifm_matrix, + offset.tolist(), + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + } + return write_compute(identity, ofm_zero_point, ofm_scale, attrs=propagator_attrs) + + +@register_matcher +def match_ethosu_identity(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU identity. - dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale) + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. - return dmaed_ofm + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. 
+ """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + identity = write.op.input_tensors[0] + if identity.op.name != "ethosu_identity": + return None + read = identity.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + input_tensors_shape = input_tensors[0].shape + length = len(input_tensors_shape) + assert length <= 4 + channels = int(input_tensors_shape[length - 1]) if length >= 3 else 1 + + subkernels = len(device_config.get_kernel_steps(identity.op.name, 1, 1, ifm_dtype)) + + input_layout = output_layout = "NHWC" + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + identity.op.attrs, + output_tensor.shape, + channels, + channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + 1, + 1, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + ) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_identity_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_identity_matcher.py new file mode 100644 index 0000000000000..4609a5bc3779a --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_identity_matcher.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.identity import match_ethosu_identity, identity_compute +from .infra import make_matrices + + +def test_ethosu_identity_matcher(): + ofm_channels = 21 + ifm_shape = (1, 12, 15, ofm_channels) + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = identity_compute( + ifm=ifm, + lut=lut, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + activation="NONE", + ) + + length = len(ifm.shape) + ifm_transform = np.identity(length + 1).tolist() + ifm_offset = np.zeros(length, dtype="int64").tolist() + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_identity(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index ce617d14fac2b..b6b78c3357605 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -37,6 +37,10 @@ ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"] +def is_u55_accel_type(accel_type): + return "u55" in accel_type + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 2), (1, 55, 55, 3)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @@ -270,9 +274,7 @@ def binary_elementwise(lhs, rhs): shapes=[ifm_shape, ifm2_shape], ranges=[(0, 1), (0, 2)], accel_type=accel_type, - # non 4D ops legalize into identity op that is not currently supported in the cascader - enable_cascader=(len(ifm_shape) == 4 and len(ifm2_shape) == 4) - and ("u65" not in accel_type), + enable_cascader=is_u55_accel_type(accel_type), ) @@ -301,8 +303,7 @@ def binary_elementwise(lhs, rhs): shapes=[ifm_shape, ifm2_shape], ranges=[(0, 1), (0, 2)], accel_type=accel_type, - # non 4D ops legalize into identity op that is not currently supported in the cascader - enable_cascader=False, + enable_cascader=is_u55_accel_type(accel_type), ) @@ -567,13 +568,12 @@ def generate_output_data(input_data): ethosu_mod = infra.create_ethosu_partition(cpu_mod) infra.compare_ethosu_with_reference( - # identity op is not supported in cascader ethosu_mod, input_data, output_data, accel_type, output_tolerance=1, - enable_cascader=False, + enable_cascader=is_u55_accel_type(accel_type), ) @@ -603,9 +603,12 @@ def create_model(): output_data = generate_ref_data(cpu_mod, input_data) ethosu_mod = infra.create_ethosu_partition(cpu_mod) - # reshape ops legalize into identity op that is not currently supported in the cascader infra.compare_ethosu_with_reference( - ethosu_mod, input_data, output_data, accel_type, enable_cascader=False + ethosu_mod, + input_data, + output_data, + accel_type, + enable_cascader=is_u55_accel_type(accel_type), ) @@ -626,8 +629,9 @@ def test_tflite_slice(accel_type, ifm_shape, begin, size): def slice_func(x): return tf.slice(x, begin, size) - # Ops that get legalized to identity is currently not supported by the cascader - infra.compare_tvm_with_tflite(slice_func, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + 
slice_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -642,9 +646,8 @@ def test_tflite_strided_slice(accel_type, ifm_shape, begin, end): def strided_slice_func(x): return tf.strided_slice(x, begin, end) - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite( - strided_slice_func, [ifm_shape], accel_type, enable_cascader=False + strided_slice_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) ) @@ -667,12 +670,11 @@ def abs_func(x): op = tf.math.abs(x) return op - # non-4D tensors are legalized to identity which are not supported by the cascader infra.compare_tvm_with_tflite( abs_func, [ifm_shape], accel_type, - enable_cascader=(len(ifm_shape) == 4) and ("u65" not in accel_type), + enable_cascader=is_u55_accel_type(accel_type), ) @@ -752,8 +754,9 @@ def tanh_func(x): op = tf.nn.tanh(x) return op - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + tanh_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -774,7 +777,6 @@ def concat_func(*inputs): op = tf.concat(list(inputs), axis) return op - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False) @@ -788,8 +790,9 @@ def sigmoid_function(x): op = tf.nn.sigmoid(x) return op - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + sigmoid_function, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) # This codegen test checks both, split and split_v @@ -813,7 +816,6 @@ def split_func(x): op = tf.split(x, num_or_size_splits, axis=axis) return op - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite(split_func, [ifm_shape], accel_type, enable_cascader=False) @@ -845,9 +847,12 @@ def create_model(): output_data = generate_ref_data(cpu_mod, input_data) ethosu_mod = partition_for_ethosu(cpu_mod) - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_ethosu_with_reference( - ethosu_mod, input_data, output_data, accel_type, enable_cascader=False + ethosu_mod, + input_data, + output_data, + accel_type, + enable_cascader=is_u55_accel_type(accel_type), ) @@ -860,8 +865,9 @@ def test_tflite_expand_dims(accel_type, ifm_shape, axis): def expand_dims_func(x): return tf.expand_dims(x, axis=axis) - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + expand_dims_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -875,8 +881,9 @@ def test_tflite_squeeze(accel_type, ifm_shape, axis): def squeeze_func(x): return tf.squeeze(x, axis=axis) - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type, enable_cascader=False) + 
infra.compare_tvm_with_tflite( + squeeze_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -894,8 +901,9 @@ def resize_model(x): x, size, align_corners=align_corners, half_pixel_centers=False ) - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + resize_model, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -918,8 +926,9 @@ def resize_model(x): x, size, align_corners=align_corners, half_pixel_centers=False ) - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + resize_model, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -959,9 +968,11 @@ def conv2d_transpose(x): op = tf.nn.bias_add(op, bias) return op - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite( - conv2d_transpose, [ifm_shape], accel_type=accel_type, enable_cascader=False + conv2d_transpose, + [ifm_shape], + accel_type=accel_type, + enable_cascader=is_u55_accel_type(accel_type), ) @@ -982,7 +993,6 @@ def test_tflite_pack(accel_type, ifm_shapes, axis): def pack_func(*inputs): return tf.stack(inputs, axis=axis) - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type, enable_cascader=False) @@ -998,7 +1008,6 @@ def test_tflite_unpack(accel_type, ifm_shape, axis): def unpack_func(x): return tf.unstack(x, axis=axis) - # Ops that get legalized to identity are currently not supported by the cascader infra.compare_tvm_with_tflite(unpack_func, [ifm_shape], accel_type, enable_cascader=False) @@ -1012,8 +1021,9 @@ def test_tflite_leaky_relu(accel_type, ifm_shape, alpha): def leaky_relu_func(x): return tf.nn.leaky_relu(x, alpha=alpha) - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + leaky_relu_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -1045,8 +1055,9 @@ def fully_connected(x): x = tf.nn.relu(x) return x - # Ops that get legalized to identity are currently not supported by the cascader - infra.compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type, enable_cascader=False) + infra.compare_tvm_with_tflite( + fully_connected, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + ) if __name__ == "__main__": From 62e449cb858bde9be0bdd3903f3515916bff0131 Mon Sep 17 00:00:00 2001 From: Mohamad Katanbaf Date: Wed, 1 Jun 2022 09:54:10 -0700 Subject: [PATCH 008/181] [microTVM][ARM]Add tests for arm schedules (#11472) * add more tests for arm_cpu schedules conv1d_ncw, conv1d_nwc, conv2d_NCHWc, depthwise_conv2d_NCHWc, dense_dsp, avg_ pool and max_pool tests are added. 
Co-authored-by: Mohamad --- .../relay/strategy/arm_cpu/test_avg_pool.py | 168 ++++++++++++++++++ .../relay/strategy/arm_cpu/test_conv1d_ncw.py | 117 ++++++++++++ .../relay/strategy/arm_cpu/test_conv1d_nwc.py | 145 +++++++++++++++ .../strategy/arm_cpu/test_conv2d_NCHWc.py | 138 ++++++++++++++ .../relay/strategy/arm_cpu/test_dense_dsp.py | 90 ++++++++++ .../arm_cpu/test_depthwise_conv2d_NCHWc.py | 121 +++++++++++++ .../relay/strategy/arm_cpu/test_max_pool.py | 135 ++++++++++++++ 7 files changed, 914 insertions(+) create mode 100644 tests/python/relay/strategy/arm_cpu/test_avg_pool.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_dense_dsp.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py create mode 100644 tests/python/relay/strategy/arm_cpu/test_max_pool.py diff --git a/tests/python/relay/strategy/arm_cpu/test_avg_pool.py b/tests/python/relay/strategy/arm_cpu/test_avg_pool.py new file mode 100644 index 0000000000000..31a812b38eed7 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_avg_pool.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicPoolTests: + @tvm.testing.requires_corstone300 + def test_pool( + self, + pool_type, + shape, + dtype, + pool_size, + strides, + padding, + dilation, + layout, + ceil_mode, + count_include_pad, + schedule_name, + ): + """Test a subgraph with a single pool operator.""" + ishape = shape + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + + out0 = getattr(relay.op.nn, pool_type)( + input0, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout=layout, + out_layout="", + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + out1 = getattr(relay.op.nn, pool_type)( + input1, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout=layout, + out_layout="", + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestAvgPool1d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + ( + shape, + pool_size, + strides, + padding, + dilation, + layout, + ceil_mode, + count_include_pad, + ) = tvm.testing.parameters( + ((3, 32, 27), (3,), (2,), 0, 1, "NCW", False, False), + ((3, 32, 27), (3,), (2,), 0, 1, "NWC", False, False), + ((3, 32, 27), (3,), (2,), 0, 1, "NCW", True, False), + ((3, 32, 27), (3,), (2,), 1, 1, "NCW", False, True), + ((1, 1, 32), 3, 1, 0, 1, "NCW", False, False), + ((1, 4, 20), 3, 2, 2, 1, "NCW", False, False), + ) + pool_type = tvm.testing.parameter("avg_pool1d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +class TestAvgPool2d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + ( + shape, + pool_size, + strides, + padding, + dilation, + layout, + ceil_mode, + count_include_pad, + ) = tvm.testing.parameters( + ((3, 32, 27, 27), (3, 3), (2, 2), 0, 1, "NCHW", False, False), + ((3, 32, 27, 27), (3, 3), (2, 2), 0, 1, "NHWC", False, False), + ((2, 16, 27, 27), (3, 3), (2, 2), 0, 1, "NCHW", True, False), + ((2, 27, 27, 16), (3, 3), (2, 2), 0, 1, "NHWC", True, False), + ((2, 16, 27, 27), (3, 3), (2, 2), 0, 1, "NCHW", True, True), + ((1, 25, 5, 64), (25, 5), (25, 5), 0, 1, "NHWC", False, False), + ((1, 3, 3, 256), (3, 3), (3, 3), 0, 1, "NHWC", False, False), + ((1, 8, 8, 64), (8, 8), (8, 8), 0, 1, "NHWC", False, False), + ((1, 1, 32, 32), (3, 3), 1, 0, 1, "NCHW", False, False), + ((1, 4, 32, 20), (3, 3), (2, 2), 0, 1, "NCHW", False, False), + ) + pool_type = tvm.testing.parameter("avg_pool2d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +class TestAvgPool3d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + ( + shape, + pool_size, + strides, + padding, + 
dilation, + layout, + ceil_mode, + count_include_pad, + ) = tvm.testing.parameters( + ((3, 4, 8, 27, 27), (3, 3, 3), 2, 0, 1, "NCDHW", False, False), + ) + pool_type = tvm.testing.parameter("avg_pool3d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py b/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py new file mode 100644 index 0000000000000..0f0507cfe7d3d --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_conv1d_ncw.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicConv1dTests: + @tvm.testing.requires_corstone300 + def test_conv1d( + self, + data_shape, + kernel_size, + num_filter, + strides, + padding, + dilation, + dtype, + schedule_name, + ): + """Test a subgraph with a single conv1d_ncw operator.""" + ishape = data_shape + wshape = (num_filter, data_shape[1], kernel_size) + + weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype) + + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + weight0 = relay.const(weight_data) + out0 = relay.op.nn.conv1d( + input0, + weight0, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="int32", + out_layout="NCW", + ) + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + weight1 = relay.const(weight_data) + + out1 = relay.op.nn.conv1d( + input1, + weight1, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="int32", + out_layout="NCW", + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestConv1d_ncw(BasicConv1dTests): + """This test is for conv1d_ncw.generic schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + ((4, 32, 16), 3, 12, 1, 0, 1), + ((4, 16, 
32), 3, 12, 1, 0, 1), + ((1, 12, 32), 3, 16, 1, 0, 1), + ((3, 10, 12), 4, 24, 1, 0, 1), + ((1, 7, 7), 3, 5, 1, 0, 1), + ((1, 2, 10), 4, 4, 2, (1, 1), 1), + ((1, 2, 20), 4, 4, 2, (0, 1), 1), + ((1, 4, 16), 1, 12, 1, (1, 0), 1), + ((1, 16, 24), 1, 32, 3, (2, 2), 1), + ) + dtype = tvm.testing.parameter("int8", "int16") + data_layout = tvm.testing.parameter("NCW") + schedule_name = tvm.testing.parameter("conv1d_ncw.generic") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py b/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py new file mode 100644 index 0000000000000..e430ade2fac14 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_conv1d_nwc.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicConv1dTests: + @tvm.testing.requires_corstone300 + def test_conv1d( + self, + data_shape, + kernel_size, + kernel_layout, + num_filter, + strides, + padding, + dilation, + dtype, + schedule_name, + ): + """Test a subgraph with a single conv1d_nwc operator.""" + ishape = data_shape + wshape = (kernel_size, data_shape[-1], num_filter) + weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype) + + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + weight0 = relay.const(weight_data) + out0 = relay.op.nn.conv1d( + input0, + weight0, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout="NWC", + kernel_layout="WIO", + out_dtype="int32", + out_layout="NWC", + ) + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + + if kernel_layout == "WOI": + weight1 = relay.const(np.moveaxis(weight_data, 1, -1)) + else: + weight1 = relay.const(weight_data) + + out1 = relay.op.nn.conv1d( + input1, + weight1, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout="NWC", + kernel_layout=kernel_layout, + out_dtype="int32", + out_layout="NWC", + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, 
+ ) + + +class TestConv1d_dsp(BasicConv1dTests): + """This test is for conv1d_dsp schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + ((4, 32, 16), 3, 12, 1, 0, 1), + ((4, 16, 32), 3, 12, 1, 0, 1), + ((4, 32, 16), 3, 12, 1, 0, 1), + ((1, 32, 12), 3, 16, 1, 0, 1), + # TODO: The following 4 tests fail due to https://github.com/apache/tvm/issues/11466 + # ((3, 12, 10), 4, 24, 1, 0, 1), + # ((1, 7, 7), 3, 5, 1, 0, 1), + # ((1, 10, 2), 4, 4, 2, (1, 1), 1), + # ((1, 20, 2), 4, 4, 2, (0, 1), 1), + ((1, 16, 4), 1, 12, 1, (1, 0), 1), + ((1, 24, 16), 1, 32, 3, (2, 2), 1), + ) + dtype = tvm.testing.parameter("int8", "int16") + data_layout = tvm.testing.parameter("NWC") + kernel_layout = tvm.testing.parameter("WOI") + schedule_name = tvm.testing.parameter("conv1d_dsp") + + +class TestConv1d_nwc(BasicConv1dTests): + """This test is for conv1d_nwc.generic schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + ((4, 32, 16), 3, 12, 1, 0, 1), + ((4, 16, 32), 3, 12, 1, 0, 1), + ((4, 32, 16), 3, 12, 1, 0, 1), + ((1, 32, 12), 3, 16, 1, 0, 1), + ((3, 12, 10), 4, 24, 1, 0, 1), + ((1, 7, 7), 3, 5, 1, 0, 1), + ((1, 10, 2), 4, 4, 2, (1, 1), 1), + ((1, 20, 2), 4, 4, 2, (0, 1), 1), + ((1, 16, 4), 1, 12, 1, (1, 0), 1), + ((1, 24, 16), 1, 32, 3, (2, 2), 1), + ) + dtype = tvm.testing.parameter("int8", "int16") + data_layout = tvm.testing.parameter("NWC") + kernel_layout = tvm.testing.parameter("WIO") + schedule_name = tvm.testing.parameter("conv1d_nwc.generic") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py new file mode 100644 index 0000000000000..3b43d37c9075f --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_NCHWc.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicConv2dTests: + @tvm.testing.requires_corstone300 + def test_conv2d_NCHWc( + self, + data_shape, + kernel_size, + data_layout, + kernel_layout, + num_filter, + strides, + padding, + dilation, + dtype, + schedule_name, + ): + """Test a subgraph with a single conv2d_NCHWc operator.""" + ishape = data_shape + wshape = (num_filter, data_shape[1], *kernel_size) + weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype) + + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + weight0 = relay.const(weight_data) + out0 = relay.op.nn.contrib_conv2d_nchwc( + relay.layout_transform(input0, "NCHW", data_layout), + relay.layout_transform(weight0, "OIHW", kernel_layout), + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout=data_layout, + kernel_layout=kernel_layout, + channels=num_filter, + out_dtype="", + out_layout="", + ) + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + weight1 = relay.const(weight_data) + out1 = relay.op.nn.contrib_conv2d_nchwc( + relay.layout_transform(input1, "NCHW", data_layout), + relay.layout_transform(weight1, "OIHW", kernel_layout), + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout=data_layout, + kernel_layout=kernel_layout, + channels=num_filter, + out_dtype="", + out_layout="", + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestConv2d_NCHWc(BasicConv2dTests): + """This test is for conv2d_NCHWc.x86 schedule.""" + + ( + data_shape, + kernel_size, + num_filter, + strides, + padding, + dilation, + dtype, + kernel_layout, + data_layout, + ) = tvm.testing.parameters( + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int16", "OIHW4i4o", "NCHW4c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int32", "OIHW4i4o", "NCHW4c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int8", "OIHW2i8o", "NCHW8c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int16", "OIHW2i8o", "NCHW8c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1), (1, 1), "int32", "OIHW2i8o", "NCHW8c"), + # ResNet18 workloads + # this test does not fit in corstone300 DCTM section. 
+ # ((1, 3, 112, 112), (7, 7), 64, (2, 2), (3, 3), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 64, 28, 28), (3, 3), 64, (1, 1), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 64, 28, 28), (1, 1), 64, (1, 1), (0, 0), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 64, 28, 28), (3, 3), 128, (2, 2), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 64, 28, 28), (1, 1), 128, (2, 2), (0, 0), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 128, 14, 14), (3, 3), 128, (1, 1), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 128, 14, 14), (3, 3), 256, (2, 2), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 128, 14, 14), (1, 1), 256, (2, 2), (0, 0), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 256, 7, 7), (3, 3), 256, (1, 1), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 256, 7, 7), (3, 3), 512, (2, 2), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 256, 7, 7), (1, 1), 512, (2, 2), (0, 0), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ((1, 512, 3, 3), (3, 3), 512, (1, 1), (1, 1), (1, 1), "int8", "OIHW4i4o", "NCHW4c"), + ) + schedule_name = tvm.testing.parameter("conv2d_NCHWc.x86") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py b/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py new file mode 100644 index 0000000000000..3edffba8acaa6 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_dense_dsp.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicDenseTests: + @tvm.testing.requires_corstone300 + def test_dense(self, shape, weight_shape, dtype, schedule_name): + """Test a subgraph with a single dense operator.""" + ishape = shape + wshape = weight_shape + units = weight_shape[0] + weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype) + + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + weight0 = relay.const(weight_data) + out0 = relay.op.nn.dense( + input0, + weight0, + units=units, + out_dtype="int32", + ) + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + weight1 = relay.const(weight_data) + out1 = relay.op.nn.dense( + input1, + weight1, + units=units, + out_dtype="int32", + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestDense(BasicDenseTests): + """This test is for dense_dsp schedule.""" + + shape, weight_shape = tvm.testing.parameters( + ((1, 128), (16, 128)), + ((32, 32), (32, 32)), + ((1, 64), (1, 64)), + ((11, 2), (2, 2)), + ((1, 32), (64, 32)), + ((3, 12), (10, 12)), + ) + dtype = tvm.testing.parameter("int8", "int16") + schedule_name = tvm.testing.parameter("dense_dsp") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py new file mode 100644 index 0000000000000..69e9ab09e4c95 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d_NCHWc.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicConv2dTests: + @tvm.testing.requires_corstone300 + def test_depthwise_conv2d_NCHWc( + self, + data_shape, + kernel_size, + data_layout, + kernel_layout, + groups, + strides, + padding, + dilation, + dtype, + schedule_name, + ): + """Test a subgraph with a single depthwise_conv2d_nchwc operator.""" + ishape = data_shape + wshape = (data_shape[1], 1, *kernel_size) + weight_data = np.random.randint(low=-10, high=10, size=wshape, dtype=dtype) + groups = groups + + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + weight0 = relay.const(weight_data) + out0 = relay.op.nn.contrib_depthwise_conv2d_nchwc( + relay.layout_transform(input0, "NCHW", data_layout), + relay.layout_transform(weight0, "OIHW", kernel_layout), + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout=data_layout, + kernel_layout=kernel_layout, + groups=groups, + out_dtype="", + out_layout="", + ) + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + weight1 = relay.const(weight_data) + out1 = relay.op.nn.contrib_depthwise_conv2d_nchwc( + relay.layout_transform(input1, "NCHW", data_layout), + relay.layout_transform(weight1, "OIHW", kernel_layout), + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=dilation, + data_layout=data_layout, + kernel_layout=kernel_layout, + groups=groups, + out_dtype="", + out_layout="", + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestDepthWiseConv2d_NCHWc(BasicConv2dTests): + """This test is for depthwise_conv2d_NCHWc schedule.""" + + ( + data_shape, + kernel_size, + groups, + strides, + padding, + dilation, + kernel_layout, + data_layout, + ) = tvm.testing.parameters( + ((1, 16, 32, 32), (3, 3), 16, (1, 1), (1, 1, 1, 1), (1, 1), "OIHW1i4o", "NCHW4c"), + ((1, 16, 32, 32), (3, 3), 12, (1, 1), (1, 1, 1, 1), (1, 1), "OIHW1i8o", "NCHW8c"), + ) + dtype = tvm.testing.parameter("int8", "int16", "int32") + schedule_name = tvm.testing.parameter("depthwise_conv2d_NCHWc") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/strategy/arm_cpu/test_max_pool.py b/tests/python/relay/strategy/arm_cpu/test_max_pool.py new file mode 100644 index 0000000000000..f58a041ecb746 --- /dev/null +++ b/tests/python/relay/strategy/arm_cpu/test_max_pool.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from pickle import FALSE +import sys +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import ( + AOT_CORSTONE300_RUNNER, +) + + +class BasicPoolTests: + @tvm.testing.requires_corstone300 + def test_pool( + self, + pool_type, + shape, + dtype, + pool_size, + strides, + padding, + dilation, + layout, + ceil_mode, + schedule_name, + ): + """Test a subgraph with a single max_pool operator.""" + ishape = shape + input0 = relay.var("input", relay.TensorType(ishape, dtype)) + + out0 = getattr(relay.op.nn, pool_type)( + input0, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout=layout, + out_layout="", + ceil_mode=ceil_mode, + ) + + ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) + + input1 = relay.var("input", relay.TensorType(ishape, dtype)) + out1 = getattr(relay.op.nn, pool_type)( + input1, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout=layout, + out_layout="", + ceil_mode=ceil_mode, + ) + mod = tvm.IRModule.from_expr(relay.Function([input1], out1)) + + inputs = {"input": np.random.randint(low=-128, high=127, size=ishape, dtype=dtype)} + output_list = generate_ref_data(ref_mod, inputs) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + runner=AOT_CORSTONE300_RUNNER, + interface_api="c", + use_unpacked_api=True, + target_opts={ + "-keys": "arm_cpu", + "-mcpu": "cortex-m7", + }, + schedule_name=schedule_name, + ) + + +class TestMaxPool1d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + shape, pool_size, strides, padding, dilation, layout, ceil_mode = tvm.testing.parameters( + ((3, 32, 27), (3,), (2,), 0, 1, "NCW", True), + ((1, 32, 1), 3, 1, 0, 1, "NWC", False), + ((1, 20, 4), 3, 2, 0, 1, "NWC", False), + ) + pool_type = tvm.testing.parameter("max_pool1d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +class TestMaxPool2d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + shape, pool_size, strides, padding, dilation, layout, ceil_mode = tvm.testing.parameters( + ((2, 32, 27, 27), (3, 3), (2, 2), 0, 1, "NCHW", False), + ((2, 32, 27, 27), (3, 3), (2, 2), 0, 1, "NCHW", True), + ((1, 26, 26, 12), (2, 2), (2, 2), 0, 1, "NHWC", False), + ((1, 11, 11, 32), (2, 2), (2, 2), 0, 1, "NHWC", False), + ((1, 3, 3, 64), (2, 2), (2, 2), 0, 1, "NHWC", False), + ((1, 32, 32, 1), (3, 3), 1, 0, 1, "NHWC", False), + ((1, 32, 20, 4), (3, 3), (2, 2), 0, 1, "NHWC", False), + ((1, 32, 32, 1), (3, 3), 1, 0, 1, "NHWC", True), + ((1, 32, 20, 4), (3, 3), (2, 2), 0, 1, "NHWC", True), + ) + pool_type = tvm.testing.parameter("max_pool2d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +class TestMaxPool3d(BasicPoolTests): + """This test is for pool.arm_cpu schedule.""" + + shape, pool_size, strides, padding, dilation, layout, ceil_mode = tvm.testing.parameters( + ((3, 4, 8, 27, 27), 
(3, 3, 3), 2, 0, 1, "NCDHW", False), + ) + pool_type = tvm.testing.parameter("max_pool3d") + dtype = tvm.testing.parameter("int32") + schedule_name = tvm.testing.parameter("pool.arm_cpu") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 89c02358a13f2e744580c4615bfeb06962d71965 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Wed, 1 Jun 2022 11:51:35 -0700 Subject: [PATCH 009/181] [Relay] Plumb external codegen target via Target.current() (#11432) * [Relay] Plumb external codegen target via Target.current() for all external codegen paths (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). We want both old-style (via relay.ext.$toolchain) and new-style (via "RelayToTIR" Pass attribute on target kind) external codegen to be able to access the current 'external codegen' Target instance via Target.current(). - For old-style, plumb the true Target through TEComplier and push it on the context stack before calling relay.ext.$toolchain. - For new-style, pass the CompilationConfig to the RelayToTIRTargetHook pass, make the jump from "Compiler" attribute value to Target via the new CompilationConfig::FindPrimitiveTargetForKind method, and push on the stack before invoking the custom "RelayToTIR" pass. While working on this discovered RelayToTIRTargetHook was incompatible with the VM's compilation flow since RelayToTIRTargetHook assumes all "Compiler" attributed functions are inlined. Generalize it to support both inline and global function styles. Extend Target::IsExternalCodegen to recognize target kinds with "RelayToTIR" attributes as external. Update target hooks unit test to exercise new support for outline-style, picking up the current target, and compiling via the VM. * - A bit of polishing en passant. * - Add comment as per Josh's suggestion Can't repro tests/python/contrib/test_ethosu/cascader/test_scheduler.py::test_compute_cycles_annotation failure, flake? 
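For illustration only, a minimal Python sketch (not part of this patch) of what the change enables for old-style external codegen: a hook registered under relay.ext.<toolchain> can now observe its matching Target via Target.current(). The "my_ext" kind name, its "flag" attribute, and the CSourceModuleCreate wrapping are assumptions made for the sketch, not APIs added by this PR.

import tvm

@tvm.register_func("relay.ext.my_ext")
def my_ext_codegen(func):
    # With this change the true Target is pushed onto the context stack before
    # relay.ext.my_ext is invoked, so the hook can read its target options here.
    target = tvm.target.Target.current(allow_none=True)
    flag = ""
    if target is not None and target.kind.name == "my_ext":
        flag = str(target.attrs.get("flag", ""))  # hypothetical target attribute
    symbol = func.attrs["global_symbol"]
    code = "// generated for %s (flag=%s)\n" % (symbol, flag)
    # Wrap the generated source as a runtime module bound to the function's symbol.
    return tvm.get_global_func("runtime.CSourceModuleCreate")(code, "c", [symbol], [])

New-style hooks get the equivalent behaviour from RelayToTIRTargetHook, which now takes the CompilationConfig, finds the Target whose kind matches the function's "Compiler" attribute, and pushes it before running the custom "RelayToTIR" pass (see the example_target_hooks changes below).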
--- include/tvm/relay/transform.h | 43 +++- include/tvm/target/target_kind.h | 10 + src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- .../backend/contrib/codegen_c/codegen.cc | 12 ++ src/relay/backend/contrib/ethosu/codegen.cc | 2 +- .../example_target_hooks/relay_to_tir.cc | 200 +++++++++++++----- .../contrib/example_target_hooks/target.cc | 5 +- src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/interpreter.cc | 8 +- src/relay/backend/te_compiler.cc | 57 ++--- src/relay/backend/te_compiler.h | 11 +- src/relay/backend/vm/compiler.cc | 34 +-- src/relay/backend/vm/compiler.h | 4 +- src/relay/transforms/dead_code.cc | 2 + src/relay/transforms/inline.cc | 1 + src/relay/transforms/target_hooks.cc | 150 ++++++++++--- src/target/target.cc | 8 +- tests/cpp/target_test.cc | 6 + tests/python/frontend/onnx/test_forward.py | 2 +- .../relay/dyn/test_dynamic_op_level2.py | 4 +- tests/python/relay/test_external_codegen.py | 54 +++++ tests/python/relay/test_target_hooks.py | 53 ++++- tests/python/relay/utils/external_codegen.py | 2 +- 24 files changed, 512 insertions(+), 162 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 0d518e4ed547e..6e3bddf9adf5c 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -462,11 +462,50 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); TVM_DLL Pass SimplifyExpr(); /*! - * \brief Run any registered RelayToTIR passes registered on the functions in a module. + * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds. + * + * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute. + * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the + * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole. + * + * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set + * as the 'current' target before the custom pass is executed. In this way it is possible + * for custom passes to pick up target options which may guide how they transform the IRModule. + * (Those targets are referred to as 'extern codegen targets' elsewhere). + * + * A typical custom pass will: + * - Find calls to "Compiler" attributes functions with matching compiler name. + * - Lower those function to TIR PrimFuncs. + * - Bind those functions into the IRModule under the the functions' "global_symbol" attribute. + * - Replace all calls to those functions with 'call_lowered' to the matching global. + * Care should be taken to handle multiple calls to the same function. + * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass. + * + * It is also possible (despite the pass and attribute names!) for the custom pass to proceed + * directly to a runtime::Module, which can be attached to the output IRModules "external_mods" + * attribute (taking care not to clobber any existing modules). In this case the flow is as above, + * except: + * - The runtime::Module must contain a binding for each compiled function under their + * "global_symbol" (ie runtime::Module::ImplementsFunction should return true). + * - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same + * "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body + * should be the original function body. 
In this way we always have a TVM definition matching + * every global function name. + * + * There are many existing runtime::Modules, ranging from source to object to dynamic libaries to + * entirely custom implementations. Some of those may require additional compilation using + * 'export_library' on the final build artifact. + * + * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility + * passes can be used by custom passes to take care of some of the boilerplate. + * + * TODO(mbs): Rename PreLoweringTargetHooks? + * + * \param config All available targets. * * \return The pass. */ -TVM_DLL Pass RelayToTIRTargetHook(); +TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config); /*! * \brief A pass for manifesting explicit memory allocations and rewriting diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 395d3aab6757b..4879470e76545 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -402,6 +402,16 @@ namespace attr { * See also \p Target::IsExternalCodegenFor */ constexpr const char* kIsExternalCodegen = "is_external_codegen"; + +/*! + * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name + * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass + * to apply before the TVM lowering. + * + * See also \p Target::IsExternalCodegenFor + */ +constexpr const char* kRelayToTIR = "RelayToTIR"; + } // namespace attr /*! diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 60f108aacf662..167afd2c5f782 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // lowering process directly. 
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); }, - config_->host_virtual_device)(mod); + config_)(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 99bc0bc7cb205..fd2f18aa9905b 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR(); runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace cmsisnn diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 19b8c579cd8b5..fd1c39bb92830 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -227,6 +227,14 @@ class CSourceCodegen : public CSourceModuleCodegenBase { Array variables = std::get<0>(res); String func_name = std::get<1>(res); + Optional opt_target = Target::Current(); + if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") { + Optional header = opt_target.value()->GetAttr("header"); + if (header.defined() && !header.value().empty()) { + code_stream_ << header.value().c_str() << "\n"; + } + } + // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -293,6 +301,10 @@ runtime::Module CCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); +TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .add_attr_option("header", String("")); // value is prepended to every output CModule + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 47c80b47c5790..afa17750d8a8c 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace ethosu diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c498baa6d11d2..eb6cf1cce4207 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -28,12 +28,37 @@ #include #include "../../../op/call/call.h" +#include "tvm/tir/function.h" namespace tvm { namespace relay { namespace contrib { namespace example_target_hooks { +namespace { + +/*! + * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay + * Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a + * TIR PrimFunc implementing subtraction. + * + * Illustrates six aspects a custom 'lowering' style pass may need to account for: + * - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as + * global functions. 
+ * - Let-bound lowerable functions should be inlined on-the-fly since after processing the + * let-binding is no longer required. + * - There may be multiple calls to the same lowerable function. All calls need to be + * rewritten, even though the function itself need be rewritten only once. + * - GlobalVars must be shared between all calls and the new definition itself. + * - Calls to lowered functions must use the "call_lowered" calling convention. + * - The Target::Current() may hold an instance of the TargetKind from which the custom Pass + * was extracted. + * + * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add + * runtime::Modules to the output IRModule's "external_mods" attribute. In this case the + * IRModule must be left with an 'extern' Function definition with the matching "external_symbol" + * name. + */ class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) @@ -56,51 +81,102 @@ class ConvertAddToSubtract : public MixedModeMutator { return tir::BufferLoad(buffer, {index}); } - void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { - tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); - tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); - tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); + GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) { + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + ICHECK(func_name.defined()); - tir::Var x_var("x", DataType::Handle()); - tir::Var y_var("y", DataType::Handle()); - tir::Var out_var("out", DataType::Handle()); + // -------------------------------------------------------------------------------------------- + // Cases: + // - Inline function: + // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call. + // - Thereafter (via object sharing): discover global var already in module, replace call + // - Global function: + // - Assume func_name == global_var->name_hint + // - First encounter: create global var, rewrite to PrimFunc, update binding, replace call + // - Thereafter (via global var): discover global var already in module, replace call + // -------------------------------------------------------------------------------------------- - Map dict_attrs; - dict_attrs.Set("global_symbol", new_global_var->name_hint); - dict_attrs.Set("tir.noalias", Bool(true)); + // If necessary, introduce a new global var to map the function to and copy the source type + // over for InferType. + GlobalVar global_var; + bool need_rewriting; + if (ir_module_->ContainGlobalVar(func_name.value())) { + global_var = ir_module_->GetGlobalVar(func_name.value()); + // Only rewrite to a PrimFunc if the global definition is still a Relay function. + need_rewriting = ir_module_->Lookup(global_var)->IsInstance(); + } else { + global_var = GlobalVar(func_name.value()); + global_var->checked_type_ = func->checked_type(); + need_rewriting = true; + } - te::Var index("index", DataType::Int(32)); - tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); - tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); - tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + // For illustration only, check if the current target matches the example_target_hook kind, + // and if so extract the example attribute value. 
+ int64_t example_attribute_value = 0; + Optional opt_current_target = Target::Current(); + if (opt_current_target.defined() && + opt_current_target.value()->kind->name == "example_target_hook") { + example_attribute_value = + opt_current_target.value()->GetAttr("example_attribute").value()->value; + } - Map buffer_map = { - {x_var, x_buffer}, - {y_var, y_buffer}, - {out_var, out_buffer}, - }; + if (need_rewriting) { + // The called function is still in Relay form. Convert to TIR. + tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); + tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); + tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); - tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), - buffer_map, {}, DictAttrs(dict_attrs)); + tir::Var x_var("x", DataType::Handle()); + tir::Var y_var("y", DataType::Handle()); + tir::Var out_var("out", DataType::Handle()); - // Switch to TIRToRuntime hook for testing - Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); - if (tir_to_runtime) { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); - } else { - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + Map dict_attrs; + dict_attrs.Set("global_symbol", global_var->name_hint); + dict_attrs.Set("tir.noalias", Bool(true)); + + te::Var index("index", DataType::Int(32)); + tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); + if (example_attribute_value > 0) { + // For illustration only, fold the example attribute into the result. + indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32), + static_cast(example_attribute_value))); + } + + tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index}); + tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + + Map buffer_map = { + {x_var, x_buffer}, + {y_var, y_buffer}, + {out_var, out_buffer}, + }; + + tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), + buffer_map, {}, DictAttrs(dict_attrs)); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + + ir_module_->Update(global_var, replacement_func); // Will Add if global_var is new. } - ir_module_->Add(new_global_var, replacement_func); + return global_var; } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { Expr var = this->VisitExpr(op->var); Expr value = this->VisitExpr(op->value); - // Outlineable function no longer needs let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // Inline on-the-fly if the let-bound value is lowerable. this->memo_[var] = value; } }; @@ -110,8 +186,8 @@ class ConvertAddToSubtract : public MixedModeMutator { Expr body = this->VisitExpr(op->body); auto expr = GetRef(op); - // Drop the let binding - if (this->CanLowerExpr(value)) { + if (AsLowerableFunction(value)) { + // The let binding is no longer needed since inlined on-the-fly above. 
this->memo_[expr] = this->VisitExpr(op->body); } else { Var var = Downcast(this->VisitExpr(op->var)); @@ -126,39 +202,49 @@ class ConvertAddToSubtract : public MixedModeMutator { return memo_[GetRef(op)]; } - bool CanLowerExpr(const Expr& expr) { - const auto* func = expr.as(); - if (func == nullptr) { - return false; - } - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - if (!func_name.defined()) { - return false; + const FunctionNode* AsLowerableFunction(const Expr& expr) { + if (const auto* function_node = expr.as()) { + auto func_name = function_node->GetAttr(::tvm::attr::kGlobalSymbol); + if (!func_name.defined()) { + return nullptr; + } + if (func_name != "replace_add_with_subtract") { + return nullptr; + } + return function_node; + } else if (const auto* global_var_node = expr.as()) { + return AsLowerableFunction(ir_module_->Lookup(GetRef(global_var_node))); + } else { + return nullptr; } - if (func_name != "replace_add_with_subtract") { - return false; + } + + const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) { + if (const auto* global_var_node = expr.as()) { + if (ir_module_->Lookup(GetRef(global_var_node)).as()) { + return global_var_node; + } } - return true; + return nullptr; } Expr Rewrite_(const CallNode* pre, const Expr& post) override { - if (const CallNode* call = post.as()) { - if (CanLowerExpr(call->op)) { - auto* func = call->op.as(); - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - - // Introduce a new global var to map the function to and copy the source type - // over for InferType - GlobalVar new_global_var(func_name.value()); - new_global_var->checked_type_ = func->checked_type(); - ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - + if (const auto* call = post.as()) { + GlobalVar new_op; + if (const auto* function_node = AsLowerableFunction(call->op)) { + // Add or replace the function with a PrimFunc. + new_op = ReplaceAddWithSubtractPrimFunc(GetRef(function_node)); + } else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) { + // The function has already been rewritten, so we just need to update the call. + new_op = GetRef(global_var_node); + } + if (new_op.defined()) { // Since we are replacing the Relay function with a call to a TIR function, we must use // the call_lowered op. 
CallLoweredAttrs attrs; attrs.metadata.Set("relay_attrs", call->attrs); ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; - return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span); + return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span); } } @@ -171,10 +257,12 @@ class ConvertAddToSubtract : public MixedModeMutator { Target custom_target_; }; +} // namespace + transform::Pass RelayToTIR() { runtime::TypedPackedFunc pass_func = [=](IRModule ir_module, transform::PassContext pass_context) { - auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); + ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c")); return relay_to_tir.Mutate(); }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 6f1914eac4c3a..19bfa8c682986 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -34,7 +34,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) - .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 2734439cddbdc..7dba23803f8c7 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -232,7 +232,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); }, - config_->host_virtual_device)(mod); + config_)(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 65ef296516956..9661040eab308 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -945,14 +945,13 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, CompilationConfig config) { - VirtualDevice host_virtual_device = config->host_virtual_device; +IRModule Prepare(IRModule mod, const CompilationConfig& config) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), qnn::transform::Legalize(), // Figure out which devices should be used to execute. // TODO(mbs): Should ignore all existing annotations when constant folding - transform::PlanDevices(std::move(config)), + transform::PlanDevices(config), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. 
transform::FuseOps(/*fuse_opt_level=*/0), @@ -962,8 +961,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, - std::move(host_virtual_device))}); + tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 76dbfef5386dd..73b44f7361a57 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -299,11 +299,10 @@ class TECompilerImpl : public TECompilerNode { // the module's globals. Furthermore, the external codegen tool must bind the compiled // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName // here. - auto target = Target("ext_dev"); auto global_var = GlobalVar(opt_global_symbol.value()); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr}, tir::PrimFunc{nullptr}, {}, ir_module); // Collect these here as it's removed in LowerExternalFunctions() device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value()); @@ -531,14 +530,14 @@ using AnalysisRemapping = std::unordered_maptarget); + CCacheKey shape_key(func, config_->host_virtual_device->target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -733,7 +732,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Special case: device_copies are left as calls to primitive operators // (thus undoing FuseOps) so that each backend can handle them directly. - // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. + // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy + // alone. if (const auto* function_node = primitive_func.as()) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { @@ -771,10 +771,18 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Typical case: call to fused primitive Relay Function. // Find the desired target device. Target target; - if (primitive_func->GetAttr(attr::kCompiler).defined()) { - // The generic 'external device' target. - // TODO(mbs): Retire once replaced unified BYOC compiler and target machinery - target = Target("ext_dev"); + Optional opt_compiler = primitive_func->GetAttr(attr::kCompiler); + if (opt_compiler.defined()) { + // This function needs to be compiled with external codegen. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + if (opt_target.defined()) { + // The target is what's supplied by the compilation config for kind matching the + // "Compiler" name. + target = opt_target.value(); + } else { + // Legacy fallback. + target = Target("ext_dev"); + } } else { // The target corresponding to the call_node expression's annotation. 
VirtualDevice virtual_device = GetVirtualDevice(GetRef(call_node)); @@ -791,6 +799,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { IRModule module_; ProcessFn process_fn_; + /*! \brief All available targets. */ + CompilationConfig config_; // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered // primitive function as we go. Those vars will be bound in the target device-type specific @@ -799,21 +809,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { std::unordered_map primitive_functions_; String module_name_; TECompiler compiler_; - /*! - * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation - * must live. - */ - VirtualDevice host_virtual_device_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; }; Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, - host_virtual_device); + LowerTensorExprMutator lower_te(module, process_fn, config, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -1043,7 +1047,7 @@ void UpdateFunctionMetadata(BaseFunc func, } IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { + CompilationConfig config) { TECompiler compiler(module); // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten @@ -1058,8 +1062,8 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // GlobalVar, and calls updated (sticking with regular Relay Call). // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). - IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn), - std::move(host_virtual_device))(module); + IRModule updated_module = + LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(config))(module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. 
However we still need a record of them in the @@ -1159,15 +1163,14 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device) { +Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, module_name, process_fn, host_virtual_device); + return LowerTE(module, module_name, process_fn, complilation_config); }; return tvm::transform::Sequential( - {tvm::relay::transform::RelayToTIRTargetHook(), + {tvm::relay::transform::RelayToTIRTargetHook(complilation_config), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(), tvm::tir::transform::ExtractPrimFuncConstants()}); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 0b2288d6a156f..8312a20cb862b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -189,7 +189,8 @@ IRModule LowerTE( const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](BaseFunc f) {}); -/*! \brief Pass to lower an IRModule's primitive functions to TIR. +/*! + * \brief Pass to lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. It annotates all functions @@ -198,11 +199,11 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \param host_virtual_device \p VirtualDevice for host data and computations - * \returns The pass which lowers primative functions to TIR + * \param config All available targets. + * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, - VirtualDevice host_virtual_device); +transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config); + } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 5a62ac66f7365..e0b742a840906 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -523,11 +523,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { op_index = itr->second; } - // Capture the dictionary of attributes from the original primitive function so that they - // can contribute to the hash of the compiled primitive. This way we can distinguish primitives - // with the same body expression but different attributes which may arbitrarily influence code - // generation. - op_attrs[op_index] = attrs->dict; + if (attrs.defined() && attrs->dict.defined()) { + // Capture the dictionary of attributes from the original primitive function so that they + // can contribute to the hash of the compiled primitive. This way we can distinguish + // primitives with the same body expression but different attributes which may arbitrarily + // influence code generation. 
+ op_attrs[op_index] = attrs->dict; + } Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), argument_registers)); @@ -981,25 +983,25 @@ void VMCompiler::LowerImpl(IRModule mod) { } } -transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::MemoryOpt(const CompilationConfig& config) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower any new shape functions and device_copies. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Fuse & lower any new allocations. - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is // incomplete to provide memory resuse optimizations. Disable it until we can @@ -1011,10 +1013,10 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower yet again - pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); + pass_seqs.push_back(FuseAndLowerOperators(config)); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); + pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -1030,7 +1032,7 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de return transform::Sequential(std::move(pass_seqs)); } -transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& host_virtual_device) { +transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& config) { Array pass_seqs; // Hoist operators to "primitive" Functions. pass_seqs.push_back(FuseOps()); @@ -1043,7 +1045,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& hos backend::UpdateConstants(func, ¶ms_); } }, - host_virtual_device)); + config)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1094,7 +1096,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::UpdateConstants(func, ¶ms_); } }, - config_->host_virtual_device)); + config_)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. @@ -1111,7 +1113,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // external codegen. 
pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(config_->host_virtual_device)); + pass_seqs.push_back(MemoryOpt(config_)); pass_seqs.push_back(transform::InferType()); transform::Sequential seq(pass_seqs); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index a65bdc5ab3cb6..163ec399013b0 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -146,10 +146,10 @@ class VMCompiler : public runtime::ModuleNode { IRModule OptimizeModuleImpl(IRModule mod); /*! \brief Returns the passes which layout memory. */ - transform::Sequential MemoryOpt(const VirtualDevice& host_virtual_device); + transform::Sequential MemoryOpt(const CompilationConfig& config); /*! \brief Returns the passes which fuse then lower Relay primitive operators. */ - transform::Sequential FuseAndLowerOperators(const VirtualDevice& host_virtual_device); + transform::Sequential FuseAndLowerOperators(const CompilationConfig& config); /*! * \brief Populate the global function names in a map where the value is used diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index ca1e04ae59fac..45cb8271b0746 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -534,6 +534,7 @@ namespace transform { // Declared in relay/transform.h Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); // Which let bindings are pure and can be safely elided? std::unordered_map var_to_purity; if (!ignore_impurity) { @@ -566,6 +567,7 @@ Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { result->Add(kv.first, kv.second); } } + VLOG(1) << "After:" << std::endl << PrettyPrint(result); return result; }; diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index a6e26364bbc4f..c55b6778093e5 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -69,6 +69,7 @@ class Inliner : ExprMutator { for (auto arg : vanilla_call->args) { new_args.push_back(VisitExpr(arg)); } + // TODO(mbs): Does not handle multiple calls to the same global function. cur_node_->RemoveCallTo(gv); return MakeNewExpr(gv, new_args, GetRef(call_node)); } diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 0022baf881ba0..00953a1907e13 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -30,61 +30,143 @@ namespace tvm { namespace relay { namespace transform { -class TargetHookVisitor : public tvm::relay::MixedModeVisitor { - /*! \brief Collected pass list for all nodes */ - std::vector pass_list_; - /*! \brief Attribute map for all registered targets */ - TargetKindAttrMap target_attr_map_; - using tvm::relay::MixedModeVisitor::VisitExpr_; +namespace { + +/*! + * \brief A pass extracted from a target kind's "RelayToTIR" attribute, along with any + * 'external codegen' Target instance with matching kind name which should be current when + * the pass is applied. + */ +struct CustomPass { + std::string target_kind_name; + Pass pass; + Optional opt_target; + CustomPass(std::string target_kind_name, Pass pass, Optional opt_target) + : target_kind_name(std::move(target_kind_name)), + pass(std::move(pass)), + opt_target(std::move(opt_target)) {} +}; + +/*! 
+ * \brief Collect all the \p CustomPasses needed according to the "Compiler" attributes on + * inlined or global functions. + */ +class TargetHookVisitor : public MixedModeVisitor { public: - TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap("RelayToTIR")) {} + TargetHookVisitor(IRModule mod, CompilationConfig config) + : mod_(std::move(mod)), + config_(std::move(config)), + target_attr_map_(tvm::TargetKind::GetAttrMap(tvm::attr::kRelayToTIR)) {} - std::vector Visit(const IRModule& ir_mod) { - for (const auto& it : ir_mod->functions) { - if (const auto* function_node = it.second.as()) { + std::vector Visit() { + ICHECK(custom_passes_.empty()); + // To ensure the passes are run in a deterministic order we'll search for functions in + // lexicographic order. + std::vector> functions; + for (const auto& kv : mod_->functions) { + functions.emplace_back(kv.first->name_hint, kv.second); + } + std::sort(functions.begin(), functions.end()); + for (const auto& kv : functions) { + if (const auto* function_node = kv.second.as()) { + // May be a top-level function with a "Compiler" attribute. + MaybeAddPassForFunction(function_node); + } + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + // May have calls to inlined "Compiler" functions in body. VisitExpr(GetRef(function_node)); } } - return pass_list_; + return std::move(custom_passes_); } - void VisitExpr_(const LetNode* op) final { - auto pre_visit = [this](const LetNode* op) { - this->VisitExpr(op->var); - this->VisitExpr(op->value); + private: + using tvm::relay::MixedModeVisitor::VisitExpr_; + + void VisitExpr_(const LetNode* let_node) final { + auto pre_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->var); + this->VisitExpr(inner_let_node->value); }; - auto post_visit = [this](const LetNode* op) { - this->VisitExpr(op->body); - this->visit_counter_[op] += 1; + auto post_visit = [this](const LetNode* inner_let_node) { + this->VisitExpr(inner_let_node->body); + this->visit_counter_[inner_let_node] += 1; }; - ExpandANormalForm(op, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); + } + + void VisitExpr_(const FunctionNode* function_node) override { + ExprVisitor::VisitExpr_(function_node); + MaybeAddPassForFunction(function_node); } - void VisitExpr_(const FunctionNode* func) override { - ExprVisitor::VisitExpr_(func); - if (!func->GetAttr(attr::kCompiler).defined()) { + /*! + * \brief If \p function_node has a "Compiler" attribute, checks if we should include a + * matching custom pass. Otherwise no-op. + */ + void MaybeAddPassForFunction(const FunctionNode* function_node) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (!opt_compiler) { + // No external codegen required. return; } - String code_gen_name = func->GetAttr(attr::kCompiler).value(); - Optional target_kind = tvm::TargetKind::Get(code_gen_name); - if (!target_kind || !target_attr_map_.count(target_kind.value())) { + // First cross-over: use "Compiler" attribute name as target kind. + std::string kind_name = opt_compiler.value(); + Optional opt_target_kind = tvm::TargetKind::Get(kind_name); + if (!opt_target_kind || !target_attr_map_.count(opt_target_kind.value())) { + // Target kind does not exist or have the "RelayToTIR" attribute, no custom pass to consider. 
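For orientation, the hook this visitor collects is whatever a backend registers under the `kRelayToTIR` attribute of its target kind. A minimal registration sketch, modelled on the `test_external_codegen_3` kind added to `tests/cpp/target_test.cc` later in this patch — the kind name `my_backend` and the pass `MyBackendLoweringPass()` are placeholders, and the attribute-value type shown may differ from the typedef used in-tree:

```cpp
// Hypothetical backend registration. RelayToTIRTargetHook runs the registered pass once
// per kind; if the CompilationConfig carries a Target of this kind it is pushed onto the
// target stack (With<Target>) while the pass executes. With this patch, registering
// kRelayToTIR alone is also enough for Target::IsExternalCodegen() to return true.
TVM_REGISTER_TARGET_KIND("my_backend", kDLCPU)
    .set_attr<tvm::transform::Pass>(tvm::attr::kRelayToTIR, MyBackendLoweringPass());
```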
return; } - Pass custom_target_pass = target_attr_map_[target_kind.value()]; - if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { - pass_list_.push_back(custom_target_pass); + if (!seen_kinds_.emplace(kind_name).second) { + // Already accounted for custom pass. + return; } + // Second (optional) cross-over: find unique Target instance in overall available targets with + // the same kind so that it can be made available when custom pass is invoked. + Optional opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); + Pass custom_target_pass = target_attr_map_[opt_target_kind.value()]; + custom_passes_.emplace_back(std::move(kind_name), std::move(custom_target_pass), + std::move(opt_target)); } + + /*! \brief IRModule we are visiting. */ + IRModule mod_; + /*! \brief All available targets. */ + CompilationConfig config_; + /*! \brief Cached attribute map for all registered targets */ + TargetKindAttrMap target_attr_map_; + /*! \brief Which target kind names have already contributed to the custom passes list. */ + std::unordered_set seen_kinds_; + /*! + * \brief All the custom passes to run, paired with their corresponding target instances, if any. + */ + std::vector custom_passes_; }; -Pass RelayToTIRTargetHook() { - auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { - auto target_hook_visitor = TargetHookVisitor(); - std::vector pass_list = target_hook_visitor.Visit(mod); - Sequential run_hooks(pass_list); +} // namespace - return run_hooks(mod); +Pass RelayToTIRTargetHook(CompilationConfig config) { + auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { + VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + TargetHookVisitor target_hook_visitor(mod, config); + std::vector custom_passes = target_hook_visitor.Visit(); + for (const auto& custom_pass : custom_passes) { + if (custom_pass.opt_target.defined()) { + VLOG(0) << "Invoking custom pass for target " + << custom_pass.opt_target.value()->ToDebugString(); + // Push the target on the stack. + With with_target(custom_pass.opt_target.value()); + // Invoke the pass with target in scope. + mod = custom_pass.pass(mod); + } else { + // Invoke the pass. 
+ VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; + mod = custom_pass.pass(mod); + } + } + VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); } diff --git a/src/target/target.cc b/src/target/target.cc index 75126ed11c70a..3cdfa0cc0d5e8 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -495,8 +495,12 @@ Target::Target(TargetKind kind, Optional host, String tag, Array attr_map = TargetKind::GetAttrMap(::tvm::attr::kIsExternalCodegen); - return attr_map.get(get()->kind, Bool(false)); + TargetKindAttrMap is_external_codegen_map = + TargetKind::GetAttrMap(tvm::attr::kIsExternalCodegen); + TargetKindAttrMap relay_to_tir_map = + TargetKind::GetAttrMap(tvm::attr::kRelayToTIR); + return is_external_codegen_map.get(get()->kind, Bool(false)) || + relay_to_tir_map.count(get()->kind); } bool Target::IsExternalCodegenFor(const Target& that) const { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index b657ac0c5783d..2c85e47e7fb89 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -144,16 +145,21 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); +TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) + .set_attr(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType()); + TEST(Target, ExternalCodegen) { Target regular("cuda"); Target external0("test_external_codegen_0"); Target external1("test_external_codegen_1"); Target external2("test_external_codegen_2"); + Target external3("test_external_codegen_3"); ASSERT_FALSE(regular.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegen()); ASSERT_TRUE(external1.IsExternalCodegen()); ASSERT_TRUE(external2.IsExternalCodegen()); + ASSERT_TRUE(external3.IsExternalCodegen()); ASSERT_TRUE(external0.IsExternalCodegenFor(regular)); ASSERT_FALSE(regular.IsExternalCodegenFor(external0)); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 41123a2548256..dbc5147e20300 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -6653,4 +6653,4 @@ def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index a017762ce35db..690ddcac8d512 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -208,6 +208,4 @@ def verify_pad_default_fill(dshape, pad_width, dtype): if __name__ == "__main__": - test_dyn_pad() - test_dyn_upsampling_infer_type_const() - test_dyn_upsampling_run() + tvm.testing.main() diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c5a9041b15fe4..4f451a125184d 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -31,6 +31,8 @@ set_external_func_attr, parametrize_external_codegen_checks, parametrize_external_json_codegen_checks, + check_graph_executor_result, + check_vm_result, ) @@ -180,6 +182,58 @@ def test_extern_gcc(check_result): check_result(mod, inputs, (2, 2), (y_data 
* y_data) - (x_data + x_data)) +# TODO(mbs): The check_aot_executor_result does not support the list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_extern_gcc_with_target_instance(check_result): + shape = (8, 8) + dtype = "int32" + + def make_mod(): + x0 = relay.var("x0", shape=shape, dtype=dtype) + y0 = relay.var("y0", shape=shape, dtype=dtype) + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.var("y", shape=shape, dtype=dtype) + call = relay.Call(f, [x, y]) + return tvm.IRModule.from_expr(call) + + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + # The header attribute is just whitespace, so compilation is as usual. + good_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "// Good"}, host=host_target + ) + # The header attribute is ill-formed, so compilation is expected to fail. + bogus_extern_codegen_target = tvm.target.Target( + {"kind": "ccompiler", "header": "Bogus"}, host=host_target + ) + + mod = make_mod() + + x_data = np.random.rand(*shape).astype(dtype) + y_data = np.random.rand(*shape).astype(dtype) + expected_result = x_data + y_data + inputs = {"x": x_data, "y": y_data} + + check_result( + mod, inputs, shape, expected_result, target=[generic_target, good_extern_codegen_target] + ) + + with pytest.raises(RuntimeError): + check_result( + mod, + inputs, + shape, + expected_result, + target=[generic_target, bogus_extern_codegen_target], + ) + + @pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") def test_extern_gcc_consts(): @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 22b3b8cb30638..046b2c7e541de 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -18,19 +18,25 @@ import sys import numpy as np import pytest +import logging +import tvm import tvm.testing from tvm import relay, IRModule from utils.external_codegen import ( + parametrize_external_codegen_checks, set_external_func_attr, check_aot_executor_result, check_graph_executor_result, + check_vm_result, ) +logging.basicConfig(level=logging.INFO) -@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) -def test_tir_external_generation(check_result): + +@parametrize_external_codegen_checks +def test_tir_external_generation_inline_without_target_instance(check_result): shape = (8,) x_data = np.random.randint(255, size=shape).astype("float32") y_data = np.random.randint(255, size=shape).astype("float32") @@ -50,6 +56,49 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +# TODO(mbs): The check_aot_executor_result does not support list-of-targets, mostly because +# tvm.testing.aot.compile_and_run requires the target to be a kind name string, and +# tvm.testing.aot.compile_models requires a single Target object. However, code outside of +# tvm.testing.aot is ready for this more general form. 
+@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_tir_external_generation_outline_with_target_instance(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + # Compile with an instance of the hooked target kind to demonstrate plumbing target attributes + # into custom passes. + host_target = tvm.target.Target("llvm") + generic_target = tvm.target.Target("llvm", host=host_target) + extern_codegen_target = tvm.target.Target( + "example_target_hook -example_attribute=42", host=host_target + ) + mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(8), float32], %y: Tensor[(8), float32]) -> Tensor[(8), float32] { + @replace_add_with_subtract(%x, %y) * 2.0f + } + + def @replace_add_with_subtract(%x: Tensor[(8), float32], %y: Tensor[(8), float32], + Inline=1, + Primitive=1, + Compiler="example_target_hook", + global_symbol="replace_add_with_subtract") -> Tensor[(8), float32] { + %x + %y // will be rewritten to TIR implementing %x - %y - 42.0f by custom pass + } + """ + ) + + check_result( + mod, + inputs, + (8,), + (x_data - y_data - 42.0) * 2.0, + target=[generic_target, extern_codegen_target], + ) + + @pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) def test_runtime_module_generation(check_result): shape = (8,) diff --git a/tests/python/relay/utils/external_codegen.py b/tests/python/relay/utils/external_codegen.py index 6d3d917ff5a23..8e5ab803de7a6 100644 --- a/tests/python/relay/utils/external_codegen.py +++ b/tests/python/relay/utils/external_codegen.py @@ -22,7 +22,7 @@ import pytest import tvm -from tvm import relay, runtime +from tvm import relay, runtime, testing from tvm.contrib import utils From 24b93f56fdbb723cc0f631ce4da0e27d7fb212b1 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 1 Jun 2022 22:59:30 +0300 Subject: [PATCH 010/181] [VM] check DLManagedTensor for conditions to construct NDArray (#11504) * check DLManagedTensor for contiguous and alignment to construct correct NDArray * correction from the reviewer * update error description for incontiguous DLTensors * small update Co-authored-by: Valery Chernov --- src/runtime/ndarray.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 0b4a9dfdd9e91..c7bfefa9a8e73 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -206,8 +206,7 @@ NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional< } NDArray NDArray::FromExternalDLTensor(const DLTensor& dl_tensor) { - ICHECK(::tvm::runtime::IsContiguous(dl_tensor)) - << "External DLTensor is not contiguous. It does not support for now"; + ICHECK(::tvm::runtime::IsContiguous(dl_tensor)) << "External DLTensor must be contiguous."; ICHECK(IsAligned(dl_tensor)) << "Data in DLTensor is not aligned as required by NDArray"; NDArray::Container* data = new NDArray::Container(); @@ -224,7 +223,7 @@ NDArray NDArray::FromExternalDLTensor(const DLTensor& dl_tensor) { NDArray NDArray::NewFromDLTensor(DLTensor* tensor, const Device& dev) { ICHECK(::tvm::runtime::IsContiguous(*tensor)) - << "DLTensor is not contiguous. It does not support for now"; + << "DLTensor is not contiguous. 
Copying from non-contiguous data is currently not supported"; std::vector shape; for (int64_t i = 0; i < tensor->ndim; i++) { shape.push_back(tensor->shape[i]); @@ -240,6 +239,9 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { data->SetDeleter(Internal::DLPackDeleter); // fill up content. data->manager_ctx = tensor; + ICHECK(::tvm::runtime::IsContiguous(tensor->dl_tensor)) << "DLManagedTensor must be contiguous."; + ICHECK(IsAligned(tensor->dl_tensor)) + << "Data in DLManagedTensor is not aligned as required by NDArray"; data->dl_tensor = tensor->dl_tensor; // update shape_ std::vector shape; From b9890dbbebeff95202a7dc65cbce3e808869cd33 Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Wed, 1 Jun 2022 13:05:30 -0700 Subject: [PATCH 011/181] [skip ci][ci][docs] Add CI infra docs (#11403) * [skip ci][ci][docs] Add CI infra docs This adds some documentation around CI infra and pointers to the guides to run a deploy. * Address comments Co-authored-by: driazati --- docs/contribute/ci.rst | 108 ---------------------- jenkins/README.md | 203 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 201 insertions(+), 110 deletions(-) diff --git a/docs/contribute/ci.rst b/docs/contribute/ci.rst index d40e4d5ab74b2..0cc1bf9dd992b 100644 --- a/docs/contribute/ci.rst +++ b/docs/contribute/ci.rst @@ -63,114 +63,6 @@ Reproduce Failures Most TVM Python tests run under |pytest|_ and can be run as described in :ref:`pr-testing`. -Keeping CI Green -**************** - -Developers rely on the TVM CI to get signal on their PRs before merging. -Occasionally breakages slip through and break ``main``, which in turn causes -the same error to show up on an PR that is based on the broken commit(s). Broken -commits can be identified `through GitHub `_ -via the commit status icon or via `Jenkins `_. -In these situations it is possible to either revert the offending commit or -submit a forward fix to address the issue. It is up to the committer and commit -author which option to choose, keeping in mind that a broken CI affects all TVM -developers and should be fixed as soon as possible. - -Skip CI for Reverts -------------------- - -For reverts and trivial forward fixes, adding ``[skip ci]`` to the revert's -PR title will cause CI to shortcut and only run lint. Committers should -take care that they only merge CI-skipped PRs to fix a failure on ``main`` and -not in cases where the submitter wants to shortcut CI to merge a change faster. -The PR title is checked when the build is first run (specifically during the lint -step, so changes after that has run do not affect CI and will require the job to -be re-triggered by another ``git push``). - -.. code:: bash - - # Revert HEAD commit, make sure to insert '[skip ci]' at the beginning of - # the commit subject - git revert HEAD - git checkout -b my_fix - # After you have pushed your branch, create a PR as usual. - git push my_repo - # Example: Skip CI on a branch with an existing PR - # Adding this commit to an existing branch will cause a new CI run where - # Jenkins is skipped - git commit --allow-empty --message "[skip ci] Trigger skipped CI" - git push my_repo - -Handling Flaky Failures -*********************** - -.. https://stackoverflow.com/questions/4743845/format-text-in-a-link-in-restructuredtext/4836544#4836544 -.. |pytest's @xfail decorator| replace:: pytest's ``@xfail`` decorator -.. _pytest's @xfail decorator: https://docs.pytest.org/en/6.2.x/skipping.html#xfail-mark-test-functions-as-expected-to-fail -.. 
|strict=True| replace:: ``strict=True`` -.. _strict=True: https://docs.pytest.org/en/6.2.x/skipping.html#strict-parameter - -If you notice a failure on your PR that seems unrelated to your change, you should -search `recent GitHub issues related to flaky tests `_ and -`file a new issue `_ -if you don't see any reports of the failure. If a certain test or class of tests affects -several PRs or commits on ``main`` with flaky failures, the test should be disabled via -|pytest's @xfail decorator|_ with |strict=True|_ and the relevant issue linked in the -disabling PR. - -.. code:: python - - @pytest.mark.xfail(strict=False, reason="Flaky test: https://github.com/apache/tvm/issues/1234") - def test_something_flaky(): - pass - -``ci-docker-staging`` -********************* - -The `ci-docker-staging `_ -branch is used to test updates to Docker images and ``Jenkinsfile`` changes. When -running a build for a normal PR from a forked repository, Jenkins uses the code -from the PR except for the ``Jenkinsfile`` itself, which comes from the base branch. -When branches are built, the ``Jenkinsfile`` in the branch is used, so a committer -with write access must push PRs to a branch in apache/tvm to properly test -``Jenkinsfile`` changes. If your PR makes changes to the ``Jenkinsfile``, make sure -to @ a `committer `_ -and ask them to push your PR as a branch to test the changes. - -.. _docker_images: - -Docker Images -************* - -.. |top_of_the_Jenkinsfile| replace:: top of the ``Jenkinsfile`` -.. _top_of_the_Jenkinsfile: https://github.com/apache/tvm/blob/7481a297740f073b193a3f09b3e27f056e8c7f2e/Jenkinsfile#L48-L54 - -Each CI job runs most of its work inside a Docker container, built from files -in the `docker/ `_ folder. These -files are built nightly in Jenkins via the `docker-images-ci `_ job. -The images for these containers are hosted in the `tlcpack Docker Hub `_ -and referenced at the |top_of_the_Jenkinsfile|_. These can be inspected and run -locally via standard Docker commands. - -.. code:: bash - - # Beware: CI images can be several GB in size - # Get a bare docker shell in the ci-gpu container - docker run -it tlcpack/ci-gpu:v0.78 /bin/bash - -``docker/bash.sh`` will automatically grab the latest image from the ``Jenkinsfile`` -and help in mounting your current directory. - -.. code:: bash - - # Run the ci_cpu image specified in Jenkinsfile - cd tvm - bash docker/bash.sh ci_cpu - # the tvm directory is automatically mounted - # example: build tvm (note: this will overrwrite build/) - $ ./tests/scripts/task_config_build_cpu.sh - $ ./tests/scripts/task_build.sh build -j32 - Reporting Issues **************** diff --git a/jenkins/README.md b/jenkins/README.md index 454664b40c643..f2f695f9fc5da 100644 --- a/jenkins/README.md +++ b/jenkins/README.md @@ -15,14 +15,213 @@ +# TVM CI + +TVM runs CI jobs on every commit to an open pull request and to branches in the apache/tvm repo (such as `main`). These jobs are essential to keeping the TVM project in a healthy state and preventing breakages. Jenkins does most of the work in running the TVM tests, though some smaller jobs are also run on GitHub Actions. + +## GitHub Actions + +GitHub Actions is used to run Windows jobs, MacOS jobs, and various on-GitHub automations. These are defined in [`.github/workflows`](../.github/workflows/). 
These automations include bots to: +* [cc people based on subscribed teams/topics](https://github.com/apache/tvm/issues/10317) +* [allow non-committers to merge approved / CI passing PRs](https://discuss.tvm.apache.org/t/rfc-allow-merging-via-pr-comments/12220) +* [add cc-ed people as reviewers on GitHub](https://discuss.tvm.apache.org/t/rfc-remove-codeowners/12095) +* [ping languishing PRs after no activity for a week (currently opt-in only)](https://github.com/apache/tvm/issues/9983) +* [push a `last-successful` branch to GitHub with the last `main` commit that passed CI](https://github.com/apache/tvm/tree/last-successful) + +https://github.com/apache/tvm/actions has the logs for each of these workflows. Note that when debugging these workflows changes from PRs from forked repositories won't be reflected in the PR. These should be tested in the forked repository first and linked in the PR body. + + +## Keeping CI Green + +Developers rely on the TVM CI to get signal on their PRs before merging. +Occasionally breakages slip through and break `main`, which in turn causes +the same error to show up on an PR that is based on the broken commit(s). Broken +commits can be identified [through GitHub](https://github.com/apache/tvm/commits/main>) +via the commit status icon or via [Jenkins](https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/activity?branch=main>). +In these situations it is possible to either revert the offending commit or +submit a forward fix to address the issue. It is up to the committer and commit +author which option to choose, keeping in mind that a broken CI affects all TVM +developers and should be fixed as soon as possible. + +Some tests are also flaky and fail for reasons unrelated to the PR. The [CI monitoring rotation](https://github.com/apache/tvm/wiki/CI-Monitoring-Runbook) watches for these failures and disables tests as necessary. It is the responsibility of those who wrote the test to ultimately fix and re-enable the test. + + +## Dealing with Flakiness + +If you notice a failure on your PR that seems unrelated to your change, you should +search [recent GitHub issues related to flaky tests](https://github.com/apache/tvm/issues?q=is%3Aissue+%5BCI+Problem%5D+Flaky+>) and +[file a new issue](https://github.com/apache/tvm/issues/new?assignees=&labels=&template=ci-problem.md&title=%5BCI+Problem%5D+>) +if you don't see any reports of the failure. If a certain test or class of tests affects +several PRs or commits on `main` with flaky failures, the test should be disabled via +[pytest's @xfail decorator](https://docs.pytest.org/en/6.2.x/skipping.html#xfail-mark-test-functions-as-expected-to-fail) with [`strict=False`](https://docs.pytest.org/en/6.2.x/skipping.html#strict-parameter) and the relevant issue linked in the +disabling PR. + +```python +@pytest.mark.xfail(strict=False, reason="Flaky test: https://github.com/apache/tvm/issues/1234") + def test_something_flaky(): + pass +``` + +Then submit a PR as usual + +```bash +git add +git commit -m'[skip ci][ci] Disable flaky test: `` + +See # +' +gh pr create +``` + +## Skipping CI + +For reverts and trivial forward fixes, adding `[skip ci]` to the revert's +PR title will cause CI to shortcut and only run lint. Committers should +take care that they only merge CI-skipped PRs to fix a failure on `main` and +not in cases where the submitter wants to shortcut CI to merge a change faster. 
+The PR title is checked when the build is first run (specifically during the lint +step, so changes after that has run do not affect CI and will require the job to +be re-triggered by another `git push`). + +```bash +# Revert HEAD commit, make sure to insert '[skip ci]' at the beginning of +# the commit subject +git revert HEAD +git checkout -b my_fix +# After you have pushed your branch, create a PR as usual. +git push my_repo +# Example: Skip CI on a branch with an existing PR +# Adding this commit to an existing branch will cause a new CI run where +# Jenkins is skipped +git commit --allow-empty --message "[skip ci] Trigger skipped CI" +git push my_repo +``` + +## Docker Images + +Each CI job runs most of its work inside a Docker container, built from files +in the [`docker/`](../docker) folder. These +files are built nightly in Jenkins via the [docker-images-ci](https://ci.tlcpack.ai/job/docker-images-ci/>) job. +The images for these containers are hosted in the [tlcpack Docker Hub](https://hub.docker.com/u/tlcpack>) +and referenced in the [`Jenkinsfile.j2`](Jenkinsfile.j2). These can be inspected and run +locally via standard Docker commands. + +### `ci-docker-staging` + +The [ci-docker-staging](https://github.com/apache/tvm/tree/ci-docker-staging>) +branch is used to test updates to Docker images and `Jenkinsfile` changes. When +running a build for a normal PR from a forked repository, Jenkins uses the code +from the PR except for the `Jenkinsfile` itself, which comes from the base branch. +When branches are built, the `Jenkinsfile` in the branch is used, so a committer +with write access must push PRs to a branch in apache/tvm to properly test +`Jenkinsfile` changes. If your PR makes changes to the `Jenkinsfile`, make sure +to @ a [committer](../CONTRIBUTORS.md>) +and ask them to push your PR as a branch to test the changes. + # Jenkins CI +TVM uses Jenkins for running Linux continuous integration (CI) tests on +[branches](https://ci.tlcpack.ai/job/tvm/) and +[pull requests](https://ci.tlcpack.ai/job/tvm/view/change-requests/) through a +build configuration specified in a [`Jenkinsfile`](../Jenkinsfile). +Other jobs run in GitHub Actions for Windows and MacOS jobs. + +## `Jenkinsfile` + The template files in this directory are used to generate the [`Jenkinsfile`](../Jenkinsfile) used by Jenkins to run CI jobs for each commit to PRs and branches. To regenerate the `Jenkinsfile`, run ```bash -pip install -r jenkins/requirements.txt -python jenkins/generate.py +python3 -mvenv _venv +_venv/bin/pip3 install -r jenkins/requirements.txt +_venv/bin/python3 jenkins/generate.py ``` +# Infrastructure + +Jenkins runs in AWS on an EC2 instance fronted by an ELB which makes it available at https://ci.tlcpack.ai. These definitions are declared via Terraform in the [tlc-pack/ci-terraform](https://github.com/tlc-pack/ci-terraform) repository. The Terraform code references custom AMIs built in [tlc-pack/ci-packer](https://github.com/tlc-pack/ci-packer). [tlc-pack/ci](https://github.com/tlc-pack/ci) contains Ansible scripts to deploy the Jenkins head node and set it up to interact with AWS. + +The Jenkins head node has a number of autoscaling groups with labels that are used to run jobs (e.g. `CPU`, `GPU` or `ARM`) via the [EC2 Fleet](https://plugins.jenkins.io/ec2-fleet/) plugin. + +## Deploying + +Deploying Jenkins can disrupt developers so it must be done with care. Jobs that are in-flight will be cancelled and must be manually restarted. 
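The Docker Images section above no longer carries the concrete commands that the deleted `docs/contribute/ci.rst` included. For readers who want them close at hand, a usage sketch — the `v0.78` tag is only an example, so check the current `Jenkinsfile` for the image actually in use:

```bash
# Beware: CI images can be several GB in size.
# Get a bare shell in the ci-gpu container (tag shown is illustrative).
docker run -it tlcpack/ci-gpu:v0.78 /bin/bash

# docker/bash.sh picks up the image referenced in the Jenkinsfile and mounts the repo.
cd tvm
bash docker/bash.sh ci_cpu
# Inside the container, e.g. configure and build (note: this overwrites build/).
./tests/scripts/task_config_build_cpu.sh
./tests/scripts/task_build.sh build -j32
```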
Follow the instructions [here](https://github.com/tlc-pack/ci/issues/10) to run a deploy. + +## Monitoring + +Dashboards of CI data can be found: +* within Jenkins at https://ci.tlcpack.ai/monitoring (HTTP / JVM stats) +* at https://monitoring.tlcpack.ai (job status, worker status) + +## CI Diagram + +This details the individual parts that interact in TVM's CI. For details on operations, see https://github.com/tlc-pack/ci. + +```mermaid +graph TD + Commit --> GitHub + GitHub --> |`push` webhook| WebhookServer(Webhook Server) + JobExecutor(Job Executor) + WebhookServer --> JobExecutor + JobExecutor --> EC2Fleet(EC2 Fleet Plugin) + EC2Fleet --> |capacity request| EC2(EC2 Autoscaler) + JobExecutor --> WorkerEC2Instance + Docker --> |build cache, artifacts| S3 + WorkerEC2Instance --> Docker + Docker --> |docker pull| G(Docker Hub) + Docker --> |docker push / pull| ECR + Docker --> |Execute jobs| CIScripts(CI Scripts) + RepoCITerraform(ci-terraform repo) --> |terraform| ECR + RepoCITerraform(ci-terraform repo) --> |terraform| EC2 + RepoCITerraform(ci-terraform repo) --> |terraform| S3 + RepoCI(ci repo) --> |configuration via Ansible| WorkerEC2Instance + RepoCIPacker(ci-packer) --> |AMIs| EC2 + Monitoring_Scrapers(Jenkins Scraper) --> Monitoring_DB(Postrgres) + Grafana --> Monitoring_DB + GitHub --> Windows + GitHub --> MacOS + + Developers --> |check PR status|JenkinsUI(Jenkins Web UI) + Monitoring_Scrapers --> |fetch job data| JenkinsUI + Developers --> |git push| Commit + Developers --> |create PR| GitHub + + subgraph Jenkins Head Node + WebhookServer + JobExecutor + EC2Fleet + JenkinsUI + end + + subgraph GitHub Actions + Windows + MacOS + end + + subgraph Configuration / Terraform + RepoCITerraform + RepoCI + RepoCIPacker + end + + subgraph Monitoring + Monitoring_DB + Grafana + Monitoring_Scrapers + end + + subgraph AWS + subgraph Jenkins Workers + WorkerEC2Instance(Worker EC2 Instance) + subgraph "Worker EC2 Instance" + Docker + CIScripts + end + end + EC2 + ECR + S3 + end + +``` From a1d95ec1ea30ac70e544a3cf10c839e228d407bf Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Wed, 1 Jun 2022 13:07:36 -0700 Subject: [PATCH 012/181] [ci] Add conditionals for non-Python tests (#11438) These don't get sharded in any way so there's no point in running them multiple times. cc Mousius areusch --- Jenkinsfile | 7 +------ jenkins/Test.groovy.j2 | 4 ++++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 44389ba767dc7..b9175f06afdc5 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,7 +45,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-05-27T14:45:11.226042 +// Generated at 2022-05-31T16:54:56.997402 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. 
--> @@ -1268,7 +1268,6 @@ def shard_run_python_i386_1_of_5() { script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", label: 'Run i386 integration tests', ) - fsim_test(ci_i386) }) } } finally { @@ -1360,7 +1359,6 @@ def shard_run_python_i386_3_of_5() { script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", label: 'Run i386 integration tests', ) - fsim_test(ci_i386) }) } } finally { @@ -1406,7 +1404,6 @@ def shard_run_python_i386_4_of_5() { script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", label: 'Run i386 integration tests', ) - fsim_test(ci_i386) }) } } finally { @@ -1452,7 +1449,6 @@ def shard_run_python_i386_5_of_5() { script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", label: 'Run i386 integration tests', ) - fsim_test(ci_i386) }) } } finally { @@ -2476,7 +2472,6 @@ def shard_run_topi_aarch64_2_of_2() { ) ci_setup(ci_arm) - cpp_unittest(ci_arm) sh ( script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", label: 'Run test_arm_compute_lib test', diff --git a/jenkins/Test.groovy.j2 b/jenkins/Test.groovy.j2 index 9f949ae717c2a..d86575c247c75 100644 --- a/jenkins/Test.groovy.j2 +++ b/jenkins/Test.groovy.j2 @@ -74,7 +74,9 @@ script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh", label: 'Run i386 integration tests', ) + {% if shard_index == 2 or num_shards < 2 %} fsim_test(ci_i386) + {% endif %} {% endcall %} {% call(shard_index, num_shards) m.sharded_test_step( name="test: Hexagon", @@ -156,7 +158,9 @@ ) %} {{ m.download_artifacts(tag='arm', filenames=tvm_multilib) }} ci_setup(ci_arm) + {% if shard_index == 1 %} cpp_unittest(ci_arm) + {% endif %} sh ( script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", label: 'Run test_arm_compute_lib test', From e84f163f573c07bb9f41209b8f722c76a92ae65d Mon Sep 17 00:00:00 2001 From: Sergey <88086617+shtinsa@users.noreply.github.com> Date: Wed, 1 Jun 2022 23:13:41 +0300 Subject: [PATCH 013/181] [TE] Optimized version of concatenation layer (#11341) * [TE] Optimized version of concatenation layer 1. Concat implemented using extern_op 2. New tests added. 3. Workaround to allow inline extern_op-s with other layers. * *test fix * test_any.py fix. * test_forward.py from tensorflow fix. * lint fix. * Fixes after code review. * New comment added. * Lint fix. * Another lint fix. * Comments added. * rebase issue fix. * Restored previous state. * Update after code review. * After code review changes. * lint review. * Change strategy for cuda to fix tests. * Rebase to main * Comments changes after review. * Some more comments fixes. * One more error fix in comments. 
* restart build --- python/tvm/relay/op/_transform.py | 7 +- python/tvm/relay/op/strategy/cuda.py | 14 ++- python/tvm/relay/op/strategy/generic.py | 21 ++++ python/tvm/relay/op/strategy/x86.py | 40 +++++-- python/tvm/topi/x86/__init__.py | 1 + python/tvm/topi/x86/concat.py | 109 ++++++++++++++++++ python/tvm/topi/x86/injective.py | 42 ++++++- src/relay/op/tensor/transform.cc | 1 - src/te/schedule/schedule_dataflow_rewrite.cc | 30 ++++- tests/python/relay/test_op_level1.py | 97 ++++++++++++++++ .../test_micro_model_library_format.py | 27 +++-- 11 files changed, 359 insertions(+), 30 deletions(-) create mode 100644 python/tvm/topi/x86/concat.py diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0338035329fcf..d87ee266f01df 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -68,7 +68,12 @@ # concatenate -_reg.register_schedule("concatenate", strategy.schedule_concatenate) +@_reg.register_compute("concatenate") +def compute_concat(attrs, inputs, output_type): + return [topi.concatenate(inputs, attrs.axis)] + + +_reg.register_strategy("concatenate", strategy.concatenate_strategy) # sliding_window @_reg.register_compute("sliding_window") diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 59971d4e206f5..4a7cff5f3f33c 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target): return topi.cuda.schedule_reduce(outs) -@schedule_concatenate.register(["cuda", "gpu"]) -def schedule_concatenate_cuda(attrs, outs, target): - """schedule concatenate for cuda""" - with target: - return topi.cuda.schedule_injective(outs) +@concatenate_strategy.register(["cuda", "gpu"]) +def concatenate_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.cuda.schedule_injective), + name="concatenate.cuda", + ) + return strategy @schedule_pool.register(["cuda", "gpu"]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index fa62af5f9fed2..2bb009dbc8f71 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1781,6 +1781,15 @@ def _compute_scanop(attrs, inputs, _): return _compute_scanop +def wrap_compute_concat(topi_compute): + """Wrap concatenate topi compute""" + + def _compute_concat(attrs, inputs, _): + return [topi_compute(inputs, attrs.axis)] + + return _compute_concat + + @override_native_generic_func("cumsum_strategy") def cumsum_strategy(attrs, inputs, out_type, target): """cumsum generic strategy""" @@ -1793,6 +1802,18 @@ def cumsum_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("concat_strategy") +def concatenate_strategy(attrs, inputs, out_type, target): + """concatenate generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.concatenate), + wrap_topi_schedule(topi.generic.schedule_injective), + name="concatenate", + ) + return strategy + + @override_native_generic_func("cumprod_strategy") def cumprod_strategy(attrs, inputs, out_type, target): """cumprod generic strategy""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0beb99e4f7dbf..59a57fd233f56 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py 
@@ -19,7 +19,7 @@ import logging import re -from tvm import topi +from tvm import topi, tir from tvm.topi.x86.utils import target_has_vnni from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition @@ -48,13 +48,6 @@ def schedule_reduce_cpu(attrs, outs, target): return topi.x86.schedule_reduce(outs) -@schedule_concatenate.register("cpu") -def schedule_concatenate_cpu(attrs, outs, target): - """schedule concatenate op for x86""" - with target: - return topi.x86.schedule_concatenate(outs) - - @schedule_pool.register("cpu") def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" @@ -741,3 +734,34 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) ) return strategy + + +@concatenate_strategy.register(["cpu"]) +def concatenate_strategy_cpu(attrs, inputs, out_type, target): + """concatenate x86 strategy""" + strategy = _op.OpStrategy() + use_only_old_concat = False + for inpt in inputs: + shape = inpt.shape + for i in shape: + if not isinstance(i, tir.expr.IntImm): + use_only_old_concat = True + break + if use_only_old_concat: + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + else: + strategy.add_implementation( + wrap_compute_concat(topi.x86.concatenate), + wrap_topi_schedule(topi.x86.schedule_concatenate_cpu), + name="concatenate.cpu", + ) + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + return strategy diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index 34a5e0362d871..d075090f01eac 100644 --- a/python/tvm/topi/x86/__init__.py +++ b/python/tvm/topi/x86/__init__.py @@ -43,3 +43,4 @@ from .scatter import * from .group_conv2d import * from .math_alter_op import * +from .concat import * diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py new file mode 100644 index 0000000000000..5cb3cd3f57d50 --- /dev/null +++ b/python/tvm/topi/x86/concat.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"concatenate related operators" +from typing import Optional +import tvm +from tvm import te +import numpy as np +from ..utils import get_const_int, const_vector + + +def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): + """Join a sequence of arrays along an existing axis. Optimized for CPU exeution. + + Parameters + ---------- + data : tuple of tvm.te.Tensor + The arrays to concatenate + + axis : int, optional + The axis along which the arrays will be joined. 
Default is 0. + + Returns + ------- + ret : tvm.te.Tensor + """ + + def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): + """Custom conactenation execution.""" + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) + for i in range(len(data)): + with i_b.for_range(0, outers[i], name="j") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return i_b.get() + + def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): + """Common case of conactenation execution.""" + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) + if inner > 1: + with i_b.for_range(0, inner, name="inn", kind="parallel") as inn: + pos = inn * outer + for i in range(len(data)): + offset = inn * outers[i] + with i_b.for_range(0, outers[i], name="j") as j: + out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] + else: + for i in range(len(data)): + with i_b.for_range(0, outers[i], name="j", kind="parallel") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return i_b.get() + + if axis < 0: + axis += len(data[0].shape) + concat_axis_sizes = [int(t.shape[axis]) for t in data] + join_size = int(np.sum(concat_axis_sizes)) + in_outers = [int(np.prod(i.shape[axis:])) for i in data] + in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] + dtype = data[0].dtype + out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :] + in_outers_tensor = const_vector(in_outers) + in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") + right_val = np.prod(out_shape[axis:]) + left_val = np.prod(out_shape[:axis]) + + if ( + len(data[0].shape) == 1 + or right_val == 1 + or (left_val == 1 and axis == len(data[0].shape) - 1) + or (left_val == 1 and right_val == 1) + ): + # badly parallelized case + return te.extern( + [out_shape], + list(data) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]), + dtype=dtype, + name="concatenate_ext", + ) + + inner = get_const_int(int(left_val)) + outer = get_const_int(int(right_val)) + return te.extern( + [out_shape], + list(data) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer), + dtype=dtype, + name="concatenate_ext", + ) diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 6492b78d6037a..78893397ba31d 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -17,20 +17,22 @@ # pylint: disable=invalid-name """x86 declaration and schedules.""" from tvm import te +from tvm.topi import tag from tvm.tir import IntImm +from tvm.topi.generic.injective import ( + schedule_injective_from_existing as schedule_injective_for_concat, +) from ..utils import is_empty_shape def schedule_injective_from_existing(sch, out): """Schedule for injective op from existing schedule. - Parameters ---------- sch: Schedule The schedule to update. out: Tensor The tensor representing the injective op. - Returns ------- sch: Schedule @@ -61,13 +63,11 @@ def schedule_injective_from_existing(sch, out): def schedule_injective(outs): """X86 schedule for injective op. 
- Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -85,13 +85,11 @@ def schedule_injective(outs): def schedule_concatenate(outs): """X86 schedule for concatenate op. - Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -132,5 +130,37 @@ def vectorize(sch, tensor, vectorize_limit): return s +def schedule_concatenate_cpu(outs): + """X86 schedule for concatenate op. + Parameters + ---------- + outs: Array of Tensor + The computation graph description in the format + of an array of tensors. + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_for_concat(s, op.output(0)) + + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + + return s + + schedule_elemwise = schedule_injective schedule_broadcast = schedule_injective diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e888eccc2b1c7..57bf9f36def93 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -346,7 +346,6 @@ RELAY_REGISTER_OP("concatenate") .set_support_level(1) .add_type_rel("Concatenate", ConcatenateRel) .set_attr("FInferCorrectLayout", ConcatenateLayout) - .set_attr("FTVMCompute", ConcatenateCompute) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 2b30055c4f424..a8363fd084cd2 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -511,6 +511,29 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { std::vector changed(sch->stages.size(), false); std::vector new_hybrid_body(sch->stages.size()); std::vector hybrid_changed(sch->stages.size(), false); + // (sshtin): this workaround allows to inline extern ops into their consumer. + // All inputs for extern op should not be inlined because inlining may happen + // before TE generation for particular extern op. That may lead to + // crash during lowering or building stages. + // The problem description: + // In case of operations fusing, arguments inlining + // prevents creation of ProducerNode for extern operation. + // Instead of the creation it is supposed to use operation argument as inlined buffer + // but extern_op TIR generation can be peformed after inlining procedure so + // newly generated TIR does not have reference to input data at all. 
+ std::unordered_map ext_ops; + for (size_t i = 0; i < sch->stages.size(); i++) { + Stage stage = sch->stages[i]; + auto ext_op = stage->op.as(); + if (ext_op) { + auto inps = ext_op->InputTensors(); + for (size_t ii = 0; ii < inps.size(); ++ii) { + if (ext_ops.find(inps[ii]->op) == ext_ops.end()) { + ext_ops[inps[ii]->op] = stage->op; + } + } + } + } // inline all the ops for (size_t i = sch->stages.size(); i != 0; --i) { Stage stage = sch->stages[i - 1]; @@ -525,8 +548,13 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { for (auto iv : compute->axis) { args.push_back(iv->var); } + if (ext_ops.find(stage->op) != ext_ops.end()) { + // sshtin: The extern op can try to get access to the input tensors as a raw data, + // that can lead to error in IR builder. + stage->attach_type = kGroupRoot; + continue; + } ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; - if (feature_extraction_mode && compute->attrs.count("const_matrix")) { // Use constant value to replace access of const matrices. // This produces wrong IR but is good enough for feature extraction purposes. diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 170850809ad54..f4afc9e90562c 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -431,6 +431,103 @@ def test_batch_norm(): ) +def do_concat_test(shapes, t_shape, dtype, axis, dev, target): + varsToConcat = [] + inputData = [] + pos = 0 + for s in shapes: + varsToConcat.append(relay.var("x{}".format(pos), shape=s)) + inputData.append(np.random.rand(*s).astype(dtype)) + pos += 1 + t = relay.var("z", shape=t_shape, dtype=dtype) + z = relay.concatenate(varsToConcat, axis=axis) + z = relay.add(z, t) + params = varsToConcat + params.append(t) + func = relay.Function(params, z) + t_data = np.random.uniform(low=-10, high=10, size=t_shape).astype(dtype) + ref_res = np.concatenate((tuple(inputData)), axis=axis) + t_data + mod = tvm.IRModule.from_expr(func) + + executor = relay.create_executor("graph", mod=mod, device=dev, target=target) + op_res1 = executor.evaluate()(*inputData, t_data) + + tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=0.000001) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + *inputData, t_data + ) + tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.000001) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate1(target, dev): + np.random.seed(471) + maxNumDimensions = 6 + shape = [4, 32, 16, 1, 31, 20, 21, 8, 28, 7] # just randomly selected 10 numbers + for dtype in ["float32"]: + for dimsNum in range(1, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(0, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. 
For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + shp = tuple(shape[:dimsNum]) + finalSize += shape[(i % len(shape))] + shapes.append( + shp[:normalizedAxis] + + tuple([shape[(i % len(shape))]]) + + shp[normalizedAxis + 1 :] + ) + t_shape = shp[:normalizedAxis] + tuple([finalSize]) + shp[normalizedAxis + 1 :] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate2(target, dev): + # test to cover cases (1, .. , x, 1, .. , 1) + np.random.seed(13) + maxNumDimensions = 6 + shape = [8, 3, 25, 33, 12, 29, 5, 11, 29, 11] # just randomly selected 10 numbers + ind = 0 + for dtype in ["float32"]: + for dimsNum in range(2, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(-dimsNum + 1, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + axisVal = [1] * dimsNum + axisVal[axis] = shape[(ind % len(shape))] + ind += 1 + finalSize += axisVal[axis] + shapes.append(tuple(axisVal)) + temp = [1] * dimsNum + temp[axis] = finalSize + t_shape = tuple(temp) + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate3(target, dev): + np.random.seed(477) + for dtype in ["float32"]: + axis = -2 + ending = 1 + shapes = [[3, 2, 1, ending], [3, 2, 1, ending]] + t_shape = [3, 2, 2, ending] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + def test_batch_norm_fold_const(): axis = 1 dtype = "float32" diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index ad054479fd7b2..d707e6b4646b7 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -22,6 +22,7 @@ import numpy import pytest +import platform import tvm import tvm.relay @@ -418,14 +419,24 @@ def test_export_byoc_c_module(): with tf.extractfile("./metadata.json") as f: metadata = json.load(f) main_md = metadata["memory"]["functions"]["main"] - assert main_md == [ - { - "constants_size_bytes": 0, - "device": 1, - "io_size_bytes": 4800, - "workspace_size_bytes": 800, - } - ] + if platform.architecture()[0] == "64bit": + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 1264, + } + ] + else: + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 1248, + } + ] if __name__ == "__main__": From a329df40289eeca45163454bc1998a998d151d26 Mon Sep 17 00:00:00 2001 From: Ziheng Jiang Date: Wed, 1 Jun 2022 13:25:05 -0700 Subject: [PATCH 014/181] [COMMUNITY] driazati -> Committer (#11525) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b0ad37c4e545c..cfd99ae73f653 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -62,6 +62,7 @@ We do encourage everyone to work anything they are interested in. 
- [Lily Orth-Smith](https://github.com/electriclilies): @electriclilies - relay - [Krzysztof Parzyszek](https://github.com/kparzysz-quic) (PMC): @kparzysz-quic - hexagon, llvm - [Andrew Reusch](https://github.com/areusch): (PMC) @areusch - runtime, microTVM +- [David Riazati](https://github.com/driazati): @driazati - ci, community - [Jared Roesch](https://github.com/jroesch) (PMC): @jroesch - relay - [Gustavo Romero](https://github.com/gromero): @gromero - microtvm, tvmc - [Giuseppe Rossini](https://github.com/giuseros): @giuseros - aot, arm From ce60bfa0ff014752e879ea5eae7ad87a9d32bc2c Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Wed, 1 Jun 2022 15:16:09 -0700 Subject: [PATCH 015/181] [ci] Add filter to teams (#11455) This improves the parsing to avoid issues like in #11454 commit-id:53a06ab3 Co-authored-by: driazati --- tests/python/ci/test_ci.py | 15 +++++++++++++++ tests/scripts/github_tag_teams.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/python/ci/test_ci.py b/tests/python/ci/test_ci.py index f5297c7ae7cce..042c109dd9d49 100644 --- a/tests/python/ci/test_ci.py +++ b/tests/python/ci/test_ci.py @@ -511,6 +511,7 @@ def run(type, data, check): """ comment2 = """ something @person4 + @person5 """ teams = { "data": { @@ -731,6 +732,20 @@ def run(type, data, check): check="Dry run, would have updated issues/1234 with {'body': '@person2 @SOME1-ONE-\\n\\ncc @person1'}", ) + run( + type="ISSUE", + data={ + "title": "[] A title", + "number": 1234, + "user": { + "login": "person5", + }, + "labels": [], + "body": "@person2 @SOME1-ONE-", + }, + check="No one to cc, exiting", + ) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/github_tag_teams.py b/tests/scripts/github_tag_teams.py index 96c22cf6a5db3..f040c1edc9780 100755 --- a/tests/scripts/github_tag_teams.py +++ b/tests/scripts/github_tag_teams.py @@ -122,7 +122,7 @@ def add_tag(tag, users): for tag in result: result[tag] = list(set(result[tag])) - return {k.lower(): v for k, v in result.items()} + return {k.lower(): v for k, v in result.items() if k.strip()} def tags_from_title(title: str) -> List[str]: From c6d7ecd0b5e71796c79b001f439322ae1d0ddbe0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 1 Jun 2022 23:57:33 -0700 Subject: [PATCH 016/181] [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518) For 0-dimensional computation, `te.CreatePrimFunc` creates an opaque block with 0 block iters, which is mistakenly passed into TVMScript auto-completion that failed to add the root block properly. 
As an example, ```python >> from tvm import te >> a = te.placeholder((), name="a", dtype="int32") >> b = te.placeholder((), name="b", dtype="int32") >> c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c") >> f = te.create_prim_func([a, b, c]) >> print(f.body.block.reads) [a[], b[]] >> print(f.body.block.writes) [c[]] ``` This PR fixes this issue by enforcing the consistency that `te.CreatePrimFunc` always creates scheduleable blocks with at least 1 block iter: ```python @T.prim_func def func(a: T.Buffer[(), "int32"], b: T.Buffer[(), "int32"], c: T.Buffer[(), "int32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") with T.block("c"): vi = T.axis.spatial(1, 0) T.reads(a[()], b[()]) T.writes(c[()]) c[()] = a[()] + b[()] ``` --- .../task_scheduler/task_scheduler.cc | 2 ++ src/te/operation/create_primfunc.cc | 8 +++++- .../unittest/test_te_create_primfunc.py | 27 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 7485f4e076cdc..fd1d95cd1f19b 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -94,6 +94,8 @@ void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc l void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; + TVM_PY_LOG(INFO, this->logging_func) + << "Initializing Task #" << task_id << ": " << task->task_name; TVM_PY_LOG(INFO, task->logging_func) << "Initializing Task #" << task_id << ": " << task->task_name; CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 03ad551c68391..27cfdd605c5d4 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -264,6 +264,12 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, } // Set script_parsing_detect_access annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); + if (iter_vars.empty()) { + IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar); + PrimExpr binding(0); + iter_vars.push_back(iter); + bindings.push_back(binding); + } // Step 6. Create Block and BlockRealize. 
return BlockRealize(/*iter_values=*/std::move(bindings), @@ -454,7 +460,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); const auto* complete = runtime::Registry::Get("script.Complete"); ICHECK(complete); - func = (*complete)(func, info.root_alloc); + func = (*complete)(std::move(func), info.root_alloc); return LayoutFreePlaceholdersNormalizer().Process(std::move(func)); } diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 014ca71a8112a..5d9ad003b487c 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -524,6 +524,32 @@ def test_int64_indices(): assert loop.extent.dtype == "int64" +def test_zero_dim_add(): + def te_func(): + a = te.placeholder((), name="a", dtype="int32") + b = te.placeholder((), name="b", dtype="int32") + c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c") + return [a, b, c] + + @T.prim_func + def expected( + a: T.Buffer[(), "int32"], + b: T.Buffer[(), "int32"], + c: T.Buffer[(), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.block("root"): + T.reads() + T.writes() + with T.block("c"): + vi = T.axis.spatial(1, 0) + T.reads(a[()], b[()]) + T.writes(c[()]) + c[()] = a[()] + b[()] + + _check_workload(te_func, expected) + + if __name__ == "__main__": test_unique_name_complete_block() test_unique_name_reduction_block() @@ -541,3 +567,4 @@ def test_int64_indices(): test_argmax_idx_val() test_argmax_val_idx() test_int64_indices() + test_zero_dim_add() From e60849c89934caa5709d4c42c5b7eda3f26c5e76 Mon Sep 17 00:00:00 2001 From: mhyang-pllab <75776819+mhyang-pllab@users.noreply.github.com> Date: Thu, 2 Jun 2022 15:53:15 +0800 Subject: [PATCH 017/181] Add ceil shape registration (#11533) --- python/tvm/relay/op/_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 23aff8bbb8b42..37cb263c489d3 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -306,3 +306,4 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("sigmoid", False, elemwise_shape_func) register_shape_func("tanh", False, elemwise_shape_func) register_shape_func("logical_not", False, elemwise_shape_func) +register_shape_func("ceil", False, elemwise_shape_func) From 4c513b9de3ebfdf4a1356f0daf7350e74ca74005 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 2 Jun 2022 01:44:05 -0700 Subject: [PATCH 018/181] [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532) This PR fixes an existing bug in TIR lowering where the TIR below triggers an error: ```python @T.prim_func def func(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.serial(10): with T.block("b"): vi = T.axis.spatial(10, i) b[vi] = a[vi] tvm.build(func, target="llvm") ``` The error message is: ``` File "/root/Projects/tvm-dev/src/tir/transforms/flatten_buffer.cc", line 173 TVMError: --------------------------------------------------------------- An error occurred during the execution of TVM. For more information, please see: https://tvm.apache.org/docs/errors.html --------------------------------------------------------------- Check failed: store->buffer->dtype == DataType::Int(8) (bool vs. int8) : Expected int8 backing array for boolean tensor ``` This PR fixes this behavior. 
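For reference, the expected post-fix form is encoded by the regression test added in this patch: boolean buffers are flattened onto `int8` backing arrays, the original boolean views are preserved via `T.preflattened_buffer`, and every access goes through an explicit cast. A sketch of that expected output, mirroring the test below:

```python
from tvm.script import tir as T

# Expected result after FlattenBuffer for the bool example above, mirroring
# the regression test added in this patch: int8 backing arrays, bool views
# recorded via T.preflattened_buffer, and explicit casts on each access.
@T.prim_func
def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
    for i0 in T.serial(10):
        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
```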
--- src/tir/transforms/flatten_buffer.cc | 18 ++++----- .../test_tir_transform_flatten_buffer.py | 37 ++++++++++++++++++- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c7cc51d27113a..21de191db0091 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -53,9 +53,7 @@ class BufferFlattener : public StmtExprMutator { static PrimFunc Flatten(PrimFunc func) { Map preflattened_buffer_map = Merge(func->buffer_map, func->preflattened_buffer_map); - auto pass = BufferFlattener(func->buffer_map); - auto writer = func.CopyOnWrite(); writer->body = pass.VisitStmt(func->body); writer->preflattened_buffer_map = preflattened_buffer_map; @@ -137,7 +135,7 @@ class BufferFlattener : public StmtExprMutator { } else { PrimExpr expr = it->second; if (expr.dtype() != var.dtype()) { - expr = Cast(var.dtype(), std::move(expr)); + expr = tvm::cast(var.dtype(), std::move(expr)); } return expr; } @@ -164,33 +162,35 @@ class BufferFlattener : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + bool store_returns_bool = (op->value.dtype() == DataType::Bool()); + store = VisitBufferAccess(store); // Handle casts from the value's dtype to the dtype of the // backing array. // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. - if (store->value.dtype() == DataType::Bool()) { + if (store_returns_bool) { ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) << "Expected int8 backing array for boolean tensor"; auto writer = store.CopyOnWrite(); - writer->value = tir::Cast(DataType::Int(8), store->value); + writer->value = tvm::cast(DataType::Int(8), store->value); + return store; } - auto flattened_indices = store->buffer->ElemOffset(store->indices); - return VisitBufferAccess(std::move(store)); + return store; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { bool load_returns_bool = (op->dtype == DataType::Bool()); BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); load = VisitBufferAccess(load); - // Handle casts from dtype of the backing array to value's dtype. // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (load_returns_bool) { ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) << "Expected int8 backing array for boolean tensor"; - return tir::Cast(DataType::Bool(), load); + load.CopyOnWrite()->dtype = DataType::Int(8); + return tvm::cast(DataType::Bool(), load); } else { return std::move(load); } diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 65be43aba3212..f1a33a4fb203d 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import tvm -from tvm import tir, te +from tvm import te, tir from tvm.script import tir as T @@ -268,6 +268,33 @@ def annotated_loops(a: T.handle) -> None: A[i] = 0.0 +@T.prim_func +def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: + for i0 in T.serial(10): + with T.block("b"): + T.reads(a[i0]) + T.writes(b[i0]) + b[i0] = a[i0] + + +@T.prim_func +def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: + T.preflattened_buffer(a, [10], dtype="bool", data=a.data) + T.preflattened_buffer(b, [10], dtype="bool", data=b.data) + # body + for i0 in T.serial(10): + b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") + + +@T.prim_func +def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: + T.preflattened_buffer(a, [10], dtype="bool", data=a.data) + T.preflattened_buffer(b, [10], dtype="bool", data=b.data) + # body + for i0 in T.serial(10): + b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") + + def test_elementwise(): _check(compacted_elementwise_func, flattened_elementwise_func) @@ -319,6 +346,13 @@ def test_annotated_loops(): tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) +def test_boolean_handling(): + _check(boolean_handling_before, boolean_handling_after) + # mod = tvm.IRModule.from_expr(boolean_handling_before) + # mod = tvm.tir.transform.FlattenBuffer()(mod) + # print(mod.script()) + + if __name__ == "__main__": test_elementwise() test_gpu_workload() @@ -329,3 +363,4 @@ def test_annotated_loops(): test_strided_buffer() test_lower_te() test_annotated_loops() + test_boolean_handling() From bbca53d2ab354d7e8bed11fc9e1eae13fbee7730 Mon Sep 17 00:00:00 2001 From: apeskov Date: Thu, 2 Jun 2022 13:04:12 +0300 Subject: [PATCH 019/181] [DNNL] Add TensorRequisite concept (#11345) Allow to use DNNL runtime in multi instance mode. Thread safe execution of Run() method. Signed-off-by: Alexander Peskov --- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 1412 +++++------------ .../contrib/dnnl/dnnl_tensor_requisite.h | 720 +++++++++ src/runtime/contrib/dnnl/dnnl_utils.cc | 24 +- src/runtime/contrib/dnnl/dnnl_utils.h | 98 +- 4 files changed, 1239 insertions(+), 1015 deletions(-) create mode 100644 src/runtime/contrib/dnnl/dnnl_tensor_requisite.h diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index f6a1c3b790807..a2417f012ea42 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -32,7 +32,12 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" -#include "dnnl.hpp" + +// TODO(@apeskov): Have to mute warning from dnnl headers. 
+// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_tensor_requisite.h" #include "dnnl_utils.h" namespace tvm { @@ -43,552 +48,82 @@ using namespace tvm::runtime; using namespace tvm::runtime::json; class DNNLJSONRuntime : public JSONRuntimeBase { - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + : JSONRuntimeBase(symbol_name, graph_json, const_names), + next_unique_eid_offset_(data_entry_.size()), + run_arg_eid_(input_var_eid_) { + for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e)); + } - const char* type_key() const { return "dnnl_json"; } + const char* type_key() const override { return "dnnl_json"; } void Init(const Array& consts) override { - BuildEngine(); - ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; // Setup constants entries for weights. SetupConstants(consts); + BuildEngine(); } - void Run() override { - // Fill in the input buffers. - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto eid = EntryID(input_nodes_[i], 0); - size_t offset_in_bytes = - entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); - size_t buffer_size = GetDataSize(*data_entry_[eid]); - write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); - } + /* Unused stub implementation */ + void Run() override { LOG(FATAL) << "Unreachable code"; } - // Invoke the engine through intepreting the stream. - for (size_t i = 0; i < net_.size(); ++i) { - net_.at(i).execute(stream_, net_args_.at(i)); - } - stream_.wait(); - - // Read output buffers. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto eid = EntryID(outputs_[i]); - size_t offset_in_bytes = - entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); - size_t buffer_size = GetDataSize(*data_entry_[eid]); - read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); + /* Thread safe implementation of Run. 
Keep runtime instance immutable */ + void Run(const TVMArgs& args) const { + auto arg_data_provider = makeIODataProvider(args); + auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider); + // Execute primitives one by one + for (const auto& act : net_) { + auto prim = std::get<0>(act); + auto arg_reqs = std::get<1>(act); + + // Find proper dnnl::memory buffers + std::unordered_map mem_args; + for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second); + + prim.execute(stream_, mem_args); } } - private: - tag layout2tag(std::string layout) { - static const std::map str2tag = {{"nc", tag::nc}, - {"cn", tag::cn}, - {"tn", tag::tn}, - {"nt", tag::nt}, - {"ncw", tag::ncw}, - {"nwc", tag::nwc}, - {"nchw", tag::nchw}, - {"nhwc", tag::nhwc}, - {"chwn", tag::chwn}, - {"ncdhw", tag::ncdhw}, - {"ndhwc", tag::ndhwc}, - {"oi", tag::oi}, - {"io", tag::io}, - {"oiw", tag::oiw}, - {"owi", tag::owi}, - {"wio", tag::wio}, - {"iwo", tag::iwo}, - {"oihw", tag::oihw}, - {"hwio", tag::hwio}, - {"ohwi", tag::ohwi}, - {"ihwo", tag::ihwo}, - {"iohw", tag::iohw}, - {"oidhw", tag::oidhw}, - {"dhwio", tag::dhwio}, - {"odhwi", tag::odhwi}, - {"iodhw", tag::iodhw}, - {"idhwo", tag::idhwo}, - {"goiw", tag::goiw}, - {"gowi", tag::gowi}, - {"wigo", tag::wigo}, - {"gohwi", tag::gohwi}, - {"goihw", tag::goihw}, - {"hwigo", tag::hwigo}, - {"giohw", tag::giohw}, - {"goidhw", tag::goidhw}, - {"giodhw", tag::giodhw}, - {"godhwi", tag::godhwi}, - {"dhwigo", tag::dhwigo}, - {"tnc", tag::tnc}, - {"ntc", tag::ntc}, - {"ldnc", tag::ldnc}, - {"ldigo", tag::ldigo}, - {"ldgoi", tag::ldgoi}, - {"ldio", tag::ldio}, - {"ldoi", tag::ldoi}, - {"ldgo", tag::ldgo}, - {"nCdhw16c", tag::nCdhw16c}, - {"nCdhw4c", tag::nCdhw4c}, - {"nCdhw8c", tag::nCdhw8c}, - {"nChw16c", tag::nChw16c}, - {"nChw4c", tag::nChw4c}, - {"nChw8c", tag::nChw8c}, - {"nCw16c", tag::nCw16c}, - {"nCw4c", tag::nCw4c}, - {"nCw8c", tag::nCw8c}, - {"NCw16n16c", tag::NCw16n16c}, - {"NChw16n16c", tag::NChw16n16c}, - {"NCdhw16n16c", tag::NCdhw16n16c}, - {"NCdhw32n32c", tag::NCdhw32n32c}, - {"NChw32n32c", tag::NChw32n32c}, - {"IOhw16i16o", tag::IOhw16i16o}, - {"OI16i16o", tag::OI16i16o}, - {"OI16i32o", tag::OI16i32o}, - {"OI16i64o", tag::OI16i64o}, - {"OI8i16o2i", tag::OI8i16o2i}, - {"OI8i32o2i", tag::OI8i32o2i}, - {"OI8i64o2i", tag::OI8i64o2i}, - {"OI4i16o4i", tag::OI4i16o4i}, - {"OI4i32o4i", tag::OI4i32o4i}, - {"OI4i64o4i", tag::OI4i64o4i}, - {"Ohwi32o", tag::Ohwi32o}, - {"IOdhw16i16o", tag::IOdhw16i16o}, - {"gIOhw16i16o", tag::gIOhw16i16o}, - {"gOhwi32o", tag::gOhwi32o}, - {"Goidhw16g", tag::Goidhw16g}, - {"IOw16o16i", tag::IOw16o16i}, - {"OIw16i16o", tag::OIw16i16o}, - {"OIw16i32o", tag::OIw16i32o}, - {"OIw16i64o", tag::OIw16i64o}, - {"IOw16i16o", tag::IOw16i16o}, - {"gIOw16i16o", tag::gIOw16i16o}, - {"OIw16o16i", tag::OIw16o16i}, - {"Oiw16o", tag::Oiw16o}, - {"OIw4i16o4i", tag::OIw4i16o4i}, - {"OIw4i32o4i", tag::OIw4i32o4i}, - {"OIw4i64o4i", tag::OIw4i64o4i}, - {"OIw2i8o4i", tag::OIw2i8o4i}, - {"OIw4i4o", tag::OIw4i4o}, - {"OIw4o4i", tag::OIw4o4i}, - {"Oiw4o", tag::Oiw4o}, - {"OIw8i16o2i", tag::OIw8i16o2i}, - {"OIw8i32o2i", tag::OIw8i32o2i}, - {"OIw8i64o2i", tag::OIw8i64o2i}, - {"OIw8i8o", tag::OIw8i8o}, - {"OIw8o16i2o", tag::OIw8o16i2o}, - {"OIw8o8i", tag::OIw8o8i}, - {"OIw8o4i", tag::OIw8o4i}, - {"OIw16i16o4i", tag::OIw16i16o4i}, - {"OIw16i32o4i", tag::OIw16i32o4i}, - {"OIw16i48o4i", tag::OIw16i48o4i}, - {"OIw16i64o4i", tag::OIw16i64o4i}, - {"OIw16i16o2i", tag::OIw16i16o2i}, - {"OIw16i32o2i", tag::OIw16i32o2i}, - {"OIw16i48o2i", tag::OIw16i48o2i}, 
- {"OIw16i64o2i", tag::OIw16i64o2i}, - {"OIw16o16i2o", tag::OIw16o16i2o}, - {"Owi16o", tag::Owi16o}, - {"OwI16o2i", tag::OwI16o2i}, - {"Owi4o", tag::Owi4o}, - {"Owi8o", tag::Owi8o}, - {"IOhw16o16i", tag::IOhw16o16i}, - {"Ohwi16o", tag::Ohwi16o}, - {"OhwI16o2i", tag::OhwI16o2i}, - {"Ohwi4o", tag::Ohwi4o}, - {"Ohwi8o", tag::Ohwi8o}, - {"OIhw16i16o", tag::OIhw16i16o}, - {"OIhw16i32o", tag::OIhw16i32o}, - {"OIhw16i64o", tag::OIhw16i64o}, - {"OIhw16o16i", tag::OIhw16o16i}, - {"Oihw16o", tag::Oihw16o}, - {"OIhw4i16o4i", tag::OIhw4i16o4i}, - {"OIhw4i32o4i", tag::OIhw4i32o4i}, - {"OIhw4i64o4i", tag::OIhw4i64o4i}, - {"OIhw4i4o", tag::OIhw4i4o}, - {"OIhw4o4i", tag::OIhw4o4i}, - {"Oihw4o", tag::Oihw4o}, - {"OIhw8i16o2i", tag::OIhw8i16o2i}, - {"OIhw8i32o2i", tag::OIhw8i32o2i}, - {"OIhw8i64o2i", tag::OIhw8i64o2i}, - {"OIhw8i8o", tag::OIhw8i8o}, - {"OIhw8o16i2o", tag::OIhw8o16i2o}, - {"OIhw8o8i", tag::OIhw8o8i}, - {"OIhw8o4i", tag::OIhw8o4i}, - {"OIhw2i8o4i", tag::OIhw2i8o4i}, - {"IOdhw16o16i", tag::IOdhw16o16i}, - {"Odhwi16o", tag::Odhwi16o}, - {"OdhwI16o2i", tag::OdhwI16o2i}, - {"Odhwi4o", tag::Odhwi4o}, - {"Odhwi8o", tag::Odhwi8o}, - {"OIdhw16i16o", tag::OIdhw16i16o}, - {"OIdhw16i32o", tag::OIdhw16i32o}, - {"OIdhw16i64o", tag::OIdhw16i64o}, - {"OIdhw16o16i", tag::OIdhw16o16i}, - {"Oidhw16o", tag::Oidhw16o}, - {"OIdhw4i4o", tag::OIdhw4i4o}, - {"OIdhw4o4i", tag::OIdhw4o4i}, - {"Oidhw4o", tag::Oidhw4o}, - {"OIdhw8i16o2i", tag::OIdhw8i16o2i}, - {"OIdhw8i32o2i", tag::OIdhw8i32o2i}, - {"OIdhw8i64o2i", tag::OIdhw8i64o2i}, - {"OIdhw4i16o4i", tag::OIdhw4i16o4i}, - {"OIdhw16i16o4i", tag::OIdhw16i16o4i}, - {"OIdhw16i32o4i", tag::OIdhw16i32o4i}, - {"OIdhw16i48o4i", tag::OIdhw16i48o4i}, - {"OIdhw16i64o4i", tag::OIdhw16i64o4i}, - {"OIdhw16i16o2i", tag::OIdhw16i16o2i}, - {"OIdhw16i32o2i", tag::OIdhw16i32o2i}, - {"OIdhw16i48o2i", tag::OIdhw16i48o2i}, - {"OIdhw16i64o2i", tag::OIdhw16i64o2i}, - {"OIdhw4i32o4i", tag::OIdhw4i32o4i}, - {"OIdhw4i64o4i", tag::OIdhw4i64o4i}, - {"OIdhw2i8o4i", tag::OIdhw2i8o4i}, - {"OIdhw8i8o", tag::OIdhw8i8o}, - {"OIdhw8o8i", tag::OIdhw8o8i}, - {"OIdhw8o4i", tag::OIdhw8o4i}, - {"gIOw16o16i", tag::gIOw16o16i}, - {"gOIw16i16o", tag::gOIw16i16o}, - {"gOIw16o16i", tag::gOIw16o16i}, - {"gOiw16o", tag::gOiw16o}, - {"gOIw4i16o4i", tag::gOIw4i16o4i}, - {"gOIw2i8o4i", tag::gOIw2i8o4i}, - {"gOIw4i4o", tag::gOIw4i4o}, - {"gOIw4o4i", tag::gOIw4o4i}, - {"gOiw4o", tag::gOiw4o}, - {"gOIw8i16o2i", tag::gOIw8i16o2i}, - {"gOIw8i8o", tag::gOIw8i8o}, - {"gOIw8o16i2o", tag::gOIw8o16i2o}, - {"gOIw8o8i", tag::gOIw8o8i}, - {"gOIw8o4i", tag::gOIw8o4i}, - {"gOIw16i16o4i", tag::gOIw16i16o4i}, - {"gOIw16i16o2i", tag::gOIw16i16o2i}, - {"gOIw16o16i2o", tag::gOIw16o16i2o}, - {"gOwi16o", tag::gOwi16o}, - {"gOwI16o2i", tag::gOwI16o2i}, - {"gOwi4o", tag::gOwi4o}, - {"gOwi8o", tag::gOwi8o}, - {"Goiw8g", tag::Goiw8g}, - {"Goiw16g", tag::Goiw16g}, - {"gIOhw16o16i", tag::gIOhw16o16i}, - {"gOhwi16o", tag::gOhwi16o}, - {"gOhwI16o2i", tag::gOhwI16o2i}, - {"gOhwi4o", tag::gOhwi4o}, - {"gOhwi8o", tag::gOhwi8o}, - {"Goihw16g", tag::Goihw16g}, - {"gOIhw16i16o", tag::gOIhw16i16o}, - {"gOIhw16o16i", tag::gOIhw16o16i}, - {"gOihw16o", tag::gOihw16o}, - {"gOIhw4i16o4i", tag::gOIhw4i16o4i}, - {"gOIhw2i8o4i", tag::gOIhw2i8o4i}, - {"gOIhw4i4o", tag::gOIhw4i4o}, - {"gOIhw4o4i", tag::gOIhw4o4i}, - {"gOihw4o", tag::gOihw4o}, - {"Goihw8g", tag::Goihw8g}, - {"gOIhw8i16o2i", tag::gOIhw8i16o2i}, - {"gOIhw8i8o", tag::gOIhw8i8o}, - {"gOIhw8o16i2o", tag::gOIhw8o16i2o}, - {"OIw4o8i8o4i", tag::OIw4o8i8o4i}, - {"OIdhw4o8i8o4i", tag::OIdhw4o8i8o4i}, - 
{"OIhw4o8i8o4i", tag::OIhw4o8i8o4i}, - {"OIhw2o8i8o2i", tag::OIhw2o8i8o2i}, - {"gOIw4o8i8o4i", tag::gOIw4o8i8o4i}, - {"gOIdhw4o8i8o4i", tag::gOIdhw4o8i8o4i}, - {"gOIhw4o8i8o4i", tag::gOIhw4o8i8o4i}, - {"gOIhw2o8i8o2i", tag::gOIhw2o8i8o2i}, - {"OIhw16i16o4i", tag::OIhw16i16o4i}, - {"OIhw16i32o4i", tag::OIhw16i32o4i}, - {"OIhw16i48o4i", tag::OIhw16i48o4i}, - {"OIhw16i64o4i", tag::OIhw16i64o4i}, - {"OIhw16i16o2i", tag::OIhw16i16o2i}, - {"OIhw16i32o2i", tag::OIhw16i32o2i}, - {"OIhw16i48o2i", tag::OIhw16i48o2i}, - {"OIhw16i64o2i", tag::OIhw16i64o2i}, - {"OIhw16o16i2o", tag::OIhw16o16i2o}, - {"gOIhw16i16o4i", tag::gOIhw16i16o4i}, - {"gOIhw16i16o2i", tag::gOIhw16i16o2i}, - {"gOIhw16o16i2o", tag::gOIhw16o16i2o}, - {"gOIhw8o8i", tag::gOIhw8o8i}, - {"gOIhw8o4i", tag::gOIhw8o4i}, - {"gIOdhw16i16o", tag::gIOdhw16i16o}, - {"gIOdhw16o16i", tag::gIOdhw16o16i}, - {"gOdhwi16o", tag::gOdhwi16o}, - {"gOdhwI16o2i", tag::gOdhwI16o2i}, - {"gOdhwi4o", tag::gOdhwi4o}, - {"gOdhwi8o", tag::gOdhwi8o}, - {"gOIdhw16i16o", tag::gOIdhw16i16o}, - {"gOIdhw16o16i", tag::gOIdhw16o16i}, - {"gOidhw16o", tag::gOidhw16o}, - {"gOIdhw4i4o", tag::gOIdhw4i4o}, - {"gOIdhw4o4i", tag::gOIdhw4o4i}, - {"gOidhw4o", tag::gOidhw4o}, - {"gOIdhw8i16o2i", tag::gOIdhw8i16o2i}, - {"gOIdhw4i16o4i", tag::gOIdhw4i16o4i}, - {"gOIdhw16i16o4i", tag::gOIdhw16i16o4i}, - {"gOIdhw16i16o2i", tag::gOIdhw16i16o2i}, - {"gOIdhw2i8o4i", tag::gOIdhw2i8o4i}, - {"gOIdhw8i8o", tag::gOIdhw8i8o}, - {"gOIdhw8o8i", tag::gOIdhw8o8i}, - {"gOIdhw8o4i", tag::gOIdhw8o4i}, - {"gOIw2i4o2i", tag::gOIw2i4o2i}, - {"gOIhw2i4o2i", tag::gOIhw2i4o2i}, - {"gOIdhw2i4o2i", tag::gOIdhw2i4o2i}, - {"gOIw2o4i2o", tag::gOIw2o4i2o}, - {"gOIhw2o4i2o", tag::gOIhw2o4i2o}, - {"gOIdhw2o4i2o", tag::gOIdhw2o4i2o}, - {"gOIw4i8o2i", tag::gOIw4i8o2i}, - {"gOIhw4i8o2i", tag::gOIhw4i8o2i}, - {"gOIdhw4i8o2i", tag::gOIdhw4i8o2i}, - {"gOIw4o8i2o", tag::gOIw4o8i2o}, - {"gOIhw4o8i2o", tag::gOIhw4o8i2o}, - {"gOIdhw4o8i2o", tag::gOIdhw4o8i2o}, - {"ldOi32o", tag::ldOi32o}, - {"ldOI32o4i", tag::ldOI32o4i}, - {"ldgOi32o", tag::ldgOi32o}, - {"ldgOI32o2i", tag::ldgOI32o2i}, - {"ldgOI32o4i", tag::ldgOI32o4i}, - {"OwI16o4i", tag::OwI16o4i}, - {"OhwI16o4i", tag::OhwI16o4i}, - {"gOwI16o4i", tag::gOwI16o4i}, - {"gOhwI16o4i", tag::gOhwI16o4i}, - {"OdhwI16o4i", tag::OdhwI16o4i}, - {"gOdhwI16o4i", tag::gOdhwI16o4i}, - {"Owi32o", tag::Owi32o}, - {"OwI32o2i", tag::OwI32o2i}, - {"OwI32o4i", tag::OwI32o4i}, - {"Owi48o", tag::Owi48o}, - {"OwI48o2i", tag::OwI48o2i}, - {"OwI48o4i", tag::OwI48o4i}, - {"Owi64o", tag::Owi64o}, - {"OwI64o2i", tag::OwI64o2i}, - {"OwI64o4i", tag::OwI64o4i}, - {"wIo2i", tag::wIo2i}, - {"wIo4i", tag::wIo4i}, - {"gOwi32o", tag::gOwi32o}, - {"gOwI32o2i", tag::gOwI32o2i}, - {"gOwI32o4i", tag::gOwI32o4i}, - {"gOwi48o", tag::gOwi48o}, - {"gOwI48o2i", tag::gOwI48o2i}, - {"gOwI48o4i", tag::gOwI48o4i}, - {"gOwi64o", tag::gOwi64o}, - {"gOwI64o2i", tag::gOwI64o2i}, - {"gOwI64o4i", tag::gOwI64o4i}, - {"gwio", tag::gwio}, - {"gwIo2i", tag::gwIo2i}, - {"gwIo4i", tag::gwIo4i}, - {"OhwI32o", tag::OhwI32o}, - {"OhwI32o2i", tag::OhwI32o2i}, - {"OhwI32o4i", tag::OhwI32o4i}, - {"Ohwi48o", tag::Ohwi48o}, - {"OhwI48o2i", tag::OhwI48o2i}, - {"OhwI48o4i", tag::OhwI48o4i}, - {"Ohwi64o", tag::Ohwi64o}, - {"OhwI64o2i", tag::OhwI64o2i}, - {"OhwI64o4i", tag::OhwI64o4i}, - {"hwIo2i", tag::hwIo2i}, - {"hwIo4i", tag::hwIo4i}, - {"gOhwI32o", tag::gOhwI32o}, - {"gOhwI32o2i", tag::gOhwI32o2i}, - {"gOhwI32o4i", tag::gOhwI32o4i}, - {"gOhwi48o", tag::gOhwi48o}, - {"gOhwI48o2i", tag::gOhwI48o2i}, - {"gOhwI48o4i", tag::gOhwI48o4i}, - 
{"gOhwi64o", tag::gOhwi64o}, - {"gOhwI64o2i", tag::gOhwI64o2i}, - {"gOhwI64o4i", tag::gOhwI64o4i}, - {"ghwio", tag::ghwio}, - {"ghwIo2i", tag::ghwIo2i}, - {"ghwIo4i", tag::ghwIo4i}, - {"Odhwi32o", tag::Odhwi32o}, - {"OdhwI32o2i", tag::OdhwI32o2i}, - {"OdhwI32o4i", tag::OdhwI32o4i}, - {"Odhwi48o", tag::Odhwi48o}, - {"OdhwI48o2i", tag::OdhwI48o2i}, - {"OdhwI48o4i", tag::OdhwI48o4i}, - {"Odhwi64o", tag::Odhwi64o}, - {"OdhwI64o2i", tag::OdhwI64o2i}, - {"OdhwI64o4i", tag::OdhwI64o4i}, - {"dhwIo2i", tag::dhwIo2i}, - {"dhwIo4i", tag::dhwIo4i}, - {"gOdhwi32o", tag::gOdhwi32o}, - {"gOdhwI32o2i", tag::gOdhwI32o2i}, - {"gOdhwI32o4i", tag::gOdhwI32o4i}, - {"gOdhwi48o", tag::gOdhwi48o}, - {"gOdhwI48o2i", tag::gOdhwI48o2i}, - {"gOdhwI48o4i", tag::gOdhwI48o4i}, - {"gOdhwi64o", tag::gOdhwi64o}, - {"gOdhwI64o2i", tag::gOdhwI64o2i}, - {"gOdhwI64o4i", tag::gOdhwI64o4i}, - {"gdhwio", tag::gdhwio}, - {"gdhwIo2i", tag::gdhwIo2i}, - {"gdhwIo4i", tag::gdhwIo4i}, - {"ldIo32i", tag::ldIo32i}, - {"ldgIo32i", tag::ldgIo32i}, - {"ldgIO32i2o", tag::ldgIO32i2o}, - {"nCdhw32c", tag::nCdhw32c}, - {"nChw32c", tag::nChw32c}, - {"nCw32c", tag::nCw32c}, - {"NCw32n16c", tag::NCw32n16c}, - {"NChw32n16c", tag::NChw32n16c}, - {"NCdhw32n16c", tag::NCdhw32n16c}, - {"NCw32n32c", tag::NCw32n32c}, - {"OI16i16o4i", tag::OI16i16o4i}, - {"IOw8o16i2o", tag::IOw8o16i2o}, - {"IOhw8o16i2o", tag::IOhw8o16i2o}, - {"Owhi16o", tag::Owhi16o}, - {"OIdhw8o16i2o", tag::OIdhw8o16i2o}, - {"IOdhw8o16i2o", tag::IOdhw8o16i2o}, - {"Goiw4g", tag::Goiw4g}, - {"gIOw8o16i2o", tag::gIOw8o16i2o}, - {"Goiw32g", tag::Goiw32g}, - {"Goihw4g", tag::Goihw4g}, - {"gIOhw8o16i2o", tag::gIOhw8o16i2o}, - {"Goihw32g", tag::Goihw32g}, - {"gOwhi16o", tag::gOwhi16o}, - {"IOw4i8o8i4o", tag::IOw4i8o8i4o}, - {"IOhw4i8o8i4o", tag::IOhw4i8o8i4o}, - {"IOdhw4i8o8i4o", tag::IOdhw4i8o8i4o}, - {"gIOw4i8o8i4o", tag::gIOw4i8o8i4o}, - {"gIOhw4i8o8i4o", tag::gIOhw4i8o8i4o}, - {"gIOdhw4i8o8i4o", tag::gIOdhw4i8o8i4o}, - {"gOIdhw8o16i2o", tag::gOIdhw8o16i2o}, - {"gIOdhw8o16i2o", tag::gIOdhw8o16i2o}, - {"Goidhw32g", tag::Goidhw32g}, - {"OI16i32o4i", tag::OI16i32o4i}, - {"OI16i48o4i", tag::OI16i48o4i}, - {"OI16i64o4i", tag::OI16i64o4i}, - {"OI16i16o2i", tag::OI16i16o2i}, - {"OI16i32o2i", tag::OI16i32o2i}, - {"OI16i48o2i", tag::OI16i48o2i}, - {"OI16i64o2i", tag::OI16i64o2i}, - {"OwI16i16o2i", tag::OwI16i16o2i}, - {"gOwI16i16o2i", tag::gOwI16i16o2i}, - {"OhwI16i16o2i", tag::OhwI16i16o2i}, - {"gOhwI16i16o2i", tag::gOhwI16i16o2i}, - {"OdhwI16i16o2i", tag::OdhwI16i16o2i}, - {"gOdhwI16i16o2i", tag::gOdhwI16i16o2i}, - {"OwI16i16o4i", tag::OwI16i16o4i}, - {"gOwI16i16o4i", tag::gOwI16i16o4i}, - {"OhwI16i16o4i", tag::OhwI16i16o4i}, - {"gOhwI16i16o4i", tag::gOhwI16i16o4i}, - {"OdhwI16i16o4i", tag::OdhwI16i16o4i}, - {"gOdhwI16i16o4i", tag::gOdhwI16i16o4i}, - {"OwI16i32o2i", tag::OwI16i32o2i}, - {"OwI16i32o4i", tag::OwI16i32o4i}, - {"OwI16i48o2i", tag::OwI16i48o2i}, - {"OwI16i48o4i", tag::OwI16i48o4i}, - {"OwI16i64o2i", tag::OwI16i64o2i}, - {"OwI16i64o4i", tag::OwI16i64o4i}, - {"gOwI16i32o2i", tag::gOwI16i32o2i}, - {"gOwI16i32o4i", tag::gOwI16i32o4i}, - {"gOwI16i48o2i", tag::gOwI16i48o2i}, - {"gOwI16i48o4i", tag::gOwI16i48o4i}, - {"gOwI16i64o2i", tag::gOwI16i64o2i}, - {"gOwI16i64o4i", tag::gOwI16i64o4i}, - {"OhwI16i32o2i", tag::OhwI16i32o2i}, - {"OhwI16i32o4i", tag::OhwI16i32o4i}, - {"OhwI16i48o2i", tag::OhwI16i48o2i}, - {"OhwI16i48o4i", tag::OhwI16i48o4i}, - {"OhwI16i64o2i", tag::OhwI16i64o2i}, - {"OhwI16i64o4i", tag::OhwI16i64o4i}, - {"gOhwI16i32o2i", tag::gOhwI16i32o2i}, - {"gOhwI16i32o4i", 
tag::gOhwI16i32o4i}, - {"gOhwI16i48o2i", tag::gOhwI16i48o2i}, - {"gOhwI16i48o4i", tag::gOhwI16i48o4i}, - {"gOhwI16i64o2i", tag::gOhwI16i64o2i}, - {"gOhwI16i64o4i", tag::gOhwI16i64o4i}, - {"OdhwI16i32o2i", tag::OdhwI16i32o2i}, - {"OdhwI16i32o4i", tag::OdhwI16i32o4i}, - {"OdhwI16i48o2i", tag::OdhwI16i48o2i}, - {"OdhwI16i48o4i", tag::OdhwI16i48o4i}, - {"OdhwI16i64o2i", tag::OdhwI16i64o2i}, - {"OdhwI16i64o4i", tag::OdhwI16i64o4i}, - {"gOdhwI16i32o2i", tag::gOdhwI16i32o2i}, - {"gOdhwI16i32o4i", tag::gOdhwI16i32o4i}, - {"gOdhwI16i48o2i", tag::gOdhwI16i48o2i}, - {"gOdhwI16i48o4i", tag::gOdhwI16i48o4i}, - {"gOdhwI16i64o2i", tag::gOdhwI16i64o2i}, - {"gOdhwI16i64o4i", tag::gOdhwI16i64o4i}, - {"hwioG16g", tag::hwioG16g}, - {"NCdhw40n32c", tag::NCdhw40n32c}, - {"NChw40n32c", tag::NChw40n32c}, - {"NCw40n32c", tag::NCw40n32c}, - {"OIdhw4o8i8o2i", tag::OIdhw4o8i8o2i}, - {"OIhw4o8i8o2i", tag::OIhw4o8i8o2i}, - {"OIw4o8i8o2i", tag::OIw4o8i8o2i}, - {"gOIdhw4o8i8o2i", tag::gOIdhw4o8i8o2i}, - {"gOIhw4o8i8o2i", tag::gOIhw4o8i8o2i}, - {"gOIw4o8i8o2i", tag::gOIw4o8i8o2i}, - {"IOdhw4i8o8i2o", tag::IOdhw4i8o8i2o}, - {"IOhw4i8o8i2o", tag::IOhw4i8o8i2o}, - {"IOw4i8o8i2o", tag::IOw4i8o8i2o}, - {"gIOdhw4i8o8i2o", tag::gIOdhw4i8o8i2o}, - {"gIOhw4i8o8i2o", tag::gIOhw4i8o8i2o}, - {"gIOw4i8o8i2o", tag::gIOw4i8o8i2o}, - {"NCdhw40n16c", tag::NCdhw40n16c}, - {"NCw40n16c", tag::NCw40n16c}, - {"NChw40n16c", tag::NChw40n16c}, - {"NCw2c32n8c", tag::NCw2c32n8c}, - {"NChw2c32n8c", tag::NChw2c32n8c}, - {"NCdhw2c32n8c", tag::NCdhw2c32n8c}, - {"OIw2i8o16i4o", tag::OIw2i8o16i4o}, - {"OIhw2i8o16i4o", tag::OIhw2i8o16i4o}, - {"OIdhw2i8o16i4o", tag::OIdhw2i8o16i4o}, - {"OIw2o8i16o4i", tag::OIw2o8i16o4i}, - {"OIw2o8i16o2i", tag::OIw2o8i16o2i}, - {"IOw2i8o16i4o", tag::IOw2i8o16i4o}, - {"IOw2i8o16i2o", tag::IOw2i8o16i2o}, - {"OIhw2o8i16o4i", tag::OIhw2o8i16o4i}, - {"OIhw2o8i16o2i", tag::OIhw2o8i16o2i}, - {"IOhw2i8o16i4o", tag::IOhw2i8o16i4o}, - {"IOhw2i8o16i2o", tag::IOhw2i8o16i2o}, - {"OIdhw2o8i16o4i", tag::OIdhw2o8i16o4i}, - {"OIdhw2o8i16o2i", tag::OIdhw2o8i16o2i}, - {"IOdhw2i8o16i4o", tag::IOdhw2i8o16i4o}, - {"IOdhw2i8o16i2o", tag::IOdhw2i8o16i2o}, - {"gOIw2o8i16o2i", tag::gOIw2o8i16o2i}, - {"gIOw2i8o16i2o", tag::gIOw2i8o16i2o}, - {"gIOhw2i8o16i2o", tag::gIOhw2i8o16i2o}, - {"gIOdhw2i8o16i2o", tag::gIOdhw2i8o16i2o}, - {"gOIhw2o8i16o2i", tag::gOIhw2o8i16o2i}, - {"gOIdhw2o8i16o2i", tag::gOIdhw2o8i16o2i}, - {"gOIw2o8i16o4i", tag::gOIw2o8i16o4i}, - {"gOIhw2o8i16o4i", tag::gOIhw2o8i16o4i}}; - std::string key = ""; - for (const auto& c : layout) { - if (std::isalpha(c, std::locale("C"))) { - char lower_c = std::tolower(c); - if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) { - key.push_back(c); - } else { - key.push_back(lower_c); - } - } else if (std::isdigit(c)) { - key.push_back(c); - } else { - LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl; - } - } - if (str2tag.count(key) == 0) { - LOG(WARNING) << "convert unregistered layout '" << key << "' to tag::any"; - return tag::any; + /* Override GetFunction to reimplement Run method */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + + ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + << "Found mismatch in the number of provided data entries and required."; + + Run(args); + }); } else { - return str2tag.at(key); + 
return JSONRuntimeBase::GetFunction(name, sptr_to_self); + } + } + + /* Same as makeInitDataProvider but in case of InputOutput return real DLTensor */ + TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args) const { + auto extract_dl_tensor = [](const TVMArgValue& val) -> const DLTensor* { + ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor"; + return val.IsObjectRef() ? val.operator NDArray().operator->() + : val.operator DLTensor*(); + }; + + std::map io_map; // eid to dl tensor map + for (size_t i = 0; i < run_arg_eid_.size(); i++) { + io_map[run_arg_eid_[i]] = extract_dl_tensor(args[i]); } + + // lambda with captured IO data handlers + return [io_map](uint32_t eid) -> const DLTensor* { return io_map.at(eid); }; } - std::map elt_name2algo{ + private: + const std::map elt_name2algo{ {"abs", dnnl::algorithm::eltwise_abs}, {"exp", dnnl::algorithm::eltwise_exp}, {"log", dnnl::algorithm::eltwise_log}, @@ -626,64 +161,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return std::regex_match(op_name, bias_add_pat) ? true : false; } - dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims, std::string layout) { - std::vector axis = { - 'N', 'C', 'O', 'I', 'D', 'H', 'W', - }; - dnnl::memory::dims out_dims; - std::string::iterator t = layout.begin(); - // Remove numbers in layout string to match the size of input_dims - while (t != layout.end()) { - if (*t >= '0' && *t <= '9') { - layout.erase(t); - } else { - t++; - } - } - // Push the correct shapes of each axis into the output_dims - for (auto a : axis) { - if (layout.find(a) != std::string::npos) { - dnnl::memory::dim shape = input_dims[layout.find(a)]; - char lower_a = std::tolower(a); - for (size_t i = 0; i < layout.size(); ++i) { - if (lower_a == layout[i]) { - shape *= input_dims[i]; - } - } - out_dims.push_back(shape); - } - } - // Multiply O and I with G, respectively - if (layout.find("G") != std::string::npos) { - dnnl::memory::dim G = 1; - if (layout.find("g") != std::string::npos) { - G = input_dims[layout.find("g")] * input_dims[layout.find("G")]; - } else { - G = input_dims[layout.find("G")]; - } - out_dims[0] *= G; - out_dims[1] *= G; - } - return out_dims; - } - - dnnl::memory::dims TransformStr2Dims(std::vector strs, bool dilates = false) { - dnnl::memory::dims out_dims; - if (dilates) { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str) - 1; }); - } else { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str); }); - } - return out_dims; - } - // Build up the engine based on the input graph. void BuildEngine() { engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); stream_ = dnnl::stream(engine_); + std::set io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end()); + tensor_registry_ = TensorRegistry(engine_, io_eid_set); + std::regex conv_pat(".*conv[1-3]d.*"); std::regex deconv_pat(".*deconv[1-3]d.*"); std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); @@ -725,562 +210,471 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } } - // Bind a JSON graph node entry to a DNNL memory. 
- dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc, - size_t offset = 0) { - auto eid = EntryID(entry); - if (entry_out_mem_.count(eid) == 0) { - return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset); - } - return entry_out_mem_[eid].first; - } - - // Bind a JSON graph node entry to a given DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem, - size_t offset = 0) { - auto eid = EntryID(entry); - // Since the DNNL memory has been created before calling this function, we assume the entry - // has not yet been bound to the other DNNL memory; otherwise it may have memory leak. - ICHECK_EQ(entry_out_mem_.count(eid), 0); - - entry_out_mem_[eid] = {mem, offset}; - return entry_out_mem_[eid].first; - } - void Convolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Memory shapes. - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - weights_dims_[0] = channels; - weights_dims_[1] = src_dims[1]; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1; + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? 
GetInput(nid, 2) : GetInput(nid, -1); + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = GetNodeAttr(node, "groups"); + auto src_layout = GetNodeAttr(node, "data_layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + auto wgh_layout = GetNodeAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); + wgh_tr = wgh_tr.TreatAs(wgh_layout); + + // Should support G mixed with O. Like { G*O, I, H, W } + // Use { G, O, I, H, W } weight format even if groups == 1 + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.Reshape(w_dims); } - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - if (kernel_layout == "OIHW") { - kernel_layout.insert(0, "G"); - } + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); + + // TODO(@apeskov): This is WA. In case of padded blocked tensor format we do not know original + // shapes. Example tensor {1, 10, 224, 224} with layout "NCNH8c" will lead to tensor + // {1, 2, 224, 224, 8}. Identically as for shapes {1, 11, 224, 224} or {1, 15, 224, 224}. + // + // Let's try to compensate it for weight tensor. Weight IC should match with source IC. + // Example src: [1, 3, 224, 224] with layout NCHW + // wgh: [16, 3, 3, 3] with layout OIHW2i8o -> [2, 2, 3, 3, 2, 8] + if (wgh_tr.dims()[2] != src_tr.dims()[1] / groups) { + auto wgh_croped_dims = wgh_tr.dims(); + wgh_croped_dims[2] = src_tr.dims()[1]; + auto zero_offset = dnnl::memory::dims(wgh_tr.dims().size(), 0); + wgh_tr = wgh_tr.Crop(wgh_croped_dims, zero_offset); } - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto conv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); - auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - // Conv description. - auto conv_desc = - has_bias ? 
dnnl::convolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, - conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, - dnnl::algorithm::convolution_direct, conv_src_md, - conv_weights_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r); + auto conv_desc = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, + src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), bias_tr.LayoutAny().desc(), + dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); - // Push to the network. - auto conv = dnnl::convolution_forward(conv_prim_desc); - net_.push_back(conv); - - // Data memory. - auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md); + src_tr = src_tr.RequestLayout(conv_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(conv_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(conv_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(conv_prim_desc.bias_desc()); - // Weight memory. - auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_prim_desc.weights_desc()); + auto scratchpad_tr = TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc()); - // Output memory. - auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); - - // Bias memory. - auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, conv_bias_memory); - - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } else { - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } + Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Deconvolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? 
std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_out_padding = - node.GetAttr>("output_padding"); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Memory shapes. - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - // legalize shape IOHW with layout OIHW - if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) { - std::swap(weights_dims_[0], weights_dims_[1]); - if (kernel_layout.find("OI") == 0) { - kernel_layout.replace(kernel_layout.find("OI"), 2, "IO"); - } - } - weights_dims_[0] = channels; - weights_dims_[1] = src_dims[1]; - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim OP = out_padding[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP; + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = GetNodeAttr(node, "groups"); + auto src_layout = GetNodeAttr(node, "data_layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + auto wgh_layout = GetNodeAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // TODO(@apeskov): WA. conv3dTranspose uses wrong layout specifier. IO instead of OI. + auto wgh_logic_layout = TensorRequisite::DefaultLogicLayoutFor(wgh_layout); + if (wgh_logic_layout == "OIDHW") wgh_logic_layout = "IODHW"; + if (wgh_logic_layout == "GOIDHW") wgh_logic_layout = "GIODHW"; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); + wgh_tr = wgh_tr.TreatAs(wgh_layout, wgh_logic_layout); + + // Should support G mixed with O. 
Like { G*O, I, H, W } + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.Reshape(w_dims); } - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - } + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); - auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); - auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x); - auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - - // Transposed covn2d description. - auto deconv_desc = - has_bias ? dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_bias_md, deconv_dst_md, - strides_dims, dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_dst_md, strides_dims, dilates_dims, - padding_dims_l, padding_dims_r); + // Conv description. + auto deconv_desc = dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), bias_tr.LayoutAny().desc(), + dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); - // Push to the network. - auto deconv = dnnl::deconvolution_forward(deconv_prim_desc); - net_.push_back(deconv); - - // Data memory. - auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md); - - // Weight memory. - auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_prim_desc.weights_desc()); - - // Output memory. - auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc()); + src_tr = src_tr.RequestLayout(deconv_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(deconv_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(deconv_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(deconv_prim_desc.bias_desc()); - // Bias memory. - auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, deconv_bias_memory); + auto scratchpad_tr = TensorRequisite::AsIs(deconv_prim_desc.scratchpad_desc()); - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_BIAS, deconv_bias_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } else { - // Bind memory buffers. 
- net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } + Submit(dnnl::deconvolution_forward(deconv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Dense(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim OC = out_shape[1]; - - // Memory shapes. - dnnl::memory::dims data_dims = input_shape; - dnnl::memory::dims weight_dims = weight_shape; - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = out_shape; - - // Memory descriptions. - auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; - auto dtype = dtype_dl2dnnl(dl_dtype); - auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc}); - auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc}); + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); + auto dense_desc = dnnl::inner_product_forward::desc( + dnnl::prop_kind::forward_inference, src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), + bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc()); // Enable elementwise post-ops. auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); - auto dense = dnnl::inner_product_forward(dense_prim_desc); - net_.push_back(dense); + src_tr = src_tr.RequestLayout(dense_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(dense_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(dense_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(dense_prim_desc.bias_desc()); - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - auto weight_memory = BindDNNLMemory(weight_entry, weight_md); + auto scratchpad_tr = TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc()); - // Bias memory. - auto bias_memory = dnnl::memory(bias_md, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8)); - } - - // Output memory. 
- auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_WEIGHTS, weight_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST, dst_memory}}); + Submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - auto gamma_entry = node.GetInputs()[1]; - auto beta_entry = node.GetInputs()[2]; - auto mean_entry = node.GetInputs()[3]; - auto variance_entry = node.GetInputs()[4]; - dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dim IC = data_shape[1]; - float epsilon = std::stof(node.GetAttr>("epsilon")[0]); + auto src_tr = GetInput(nid, 0); + auto gamma_tr = GetInput(nid, 1); + auto beta_tr = GetInput(nid, 2); + auto mean_tr = GetInput(nid, 3); + auto var_tr = GetInput(nid, 4); + auto dst_tr = GetOutput(nid, 0); - // Memory description. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); + auto axis = GetNodeAttr(node, "axis"); + auto epsilon = GetNodeAttr(node, "epsilon"); + auto center = GetNodeAttr(node, "center"); + auto scale = GetNodeAttr(node, "scale"); + + ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case"; - // BN description. auto bn_desc = dnnl::batch_normalization_forward::desc( - dnnl::prop_kind::forward_inference, data_md, epsilon, + dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon, dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift); auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_); - auto bn = dnnl::batch_normalization_forward(bn_prim_desc); - net_.push_back(bn); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc()); - auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc()); - - // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but - // assign an offset to beta data for runtime serialization. 
- auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0); - BindDNNLMemory(beta_entry, weight_memory, IC); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_DST, out_memory}, - {DNNL_ARG_SCALE_SHIFT, weight_memory}, - {DNNL_ARG_MEAN, mean_memory}, - {DNNL_ARG_VARIANCE, variance_memory}}); + + // Concatenate scale and shift tensors + auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(), GenUniqueEid()); + auto sc_sh_dims = scale_shift_tr.dims(); + ICHECK(sc_sh_dims.size() == 2); + ICHECK(sc_sh_dims[0] == 2); + sc_sh_dims[0] /= 2; + auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze(); + auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze(); + + auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) { + dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc()); + Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); + }; + + register_copy(gamma_tr, scale_tr); + register_copy(beta_tr, shift_tr); + + Submit(dnnl::batch_normalization_forward(bn_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_DST, dst_tr}, + {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}, + {DNNL_ARG_MEAN, mean_tr}, + {DNNL_ARG_VARIANCE, var_tr}}); } void Pooling(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); + // Setup attributes. - auto data_entry = node.GetInputs()[0]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - std::vector str_kernel = node.GetAttr>("pool_size"); - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_dilates = node.GetAttr>("dilation"); - std::string layout = node.GetAttr>("layout")[0]; + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto kernel = GetNodeAttr>(node, "pool_size"); + auto src_layout = GetNodeAttr(node, "layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); // Attributes related to AvgPool if (algo == dnnl::algorithm::pooling_avg) { - int int_countpad = std::stoi(node.GetAttr>("count_include_pad")[0]); - bool count_include_pad = int_countpad != 0 ? true : false; - algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding - : dnnl::algorithm::pooling_avg_exclude_padding; + auto include_pad = GetNodeAttr(node, "count_include_pad"); + algo = include_pad ? 
dnnl::algorithm::pooling_avg_include_padding + : dnnl::algorithm::pooling_avg_exclude_padding; } - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout); - dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout); - dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel); - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout)); - auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - // Pooling description. - auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo, - pool_src_md, pool_dst_md, strides_dims, - kernel_dims, padding_dims_l, padding_dims_r); - - auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc, engine_, true); - auto pool = dnnl::pooling_forward(pool_prim_desc); - net_.push_back(pool); + auto pool_desc = dnnl::pooling_v2_forward::desc( + dnnl::prop_kind::forward_inference, algo, src_tr.desc(), //<= Do not use any for src tensor + dst_tr.LayoutAny().desc(), strides, kernel, dilates, padding_l, padding_r); + auto pool_prim_desc = dnnl::pooling_v2_forward::primitive_desc(pool_desc, engine_); - // Memories. - auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md); + src_tr = src_tr.RequestLayout(pool_prim_desc.src_desc()); + dst_tr = dst_tr.RequestLayout(pool_prim_desc.dst_desc()); - auto pool2d_dst_memory = BindDNNLMemory(out_entry, pool_prim_desc.dst_desc()); + auto scratchpad_tr = TensorRequisite::AsIs(pool_prim_desc.scratchpad_desc()); - // Bind memory buffers. 
- net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST, pool2d_dst_memory}}); + Submit(dnnl::pooling_v2_forward(pool_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCRATCHPAD, scratchpad_tr}}); } void Eltwise(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); - auto algo = elt_name2algo[op_name]; + auto algo = elt_name2algo.at(op_name); + + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); float alpha = 0., beta = 0.; if (op_name == "clip") { - alpha = std::stof(node.GetAttr>("a_min")[0]); - beta = std::stof(node.GetAttr>("a_max")[0]); + alpha = GetNodeAttr(node, "a_min"); + beta = GetNodeAttr(node, "a_max"); } else if (op_name == "nn.leaky_relu") { - alpha = std::stof(node.GetAttr>("alpha")[0]); + alpha = GetNodeAttr(node, "alpha"); } - auto elt_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, alpha, beta); + auto elt_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, + src_tr.desc(), alpha, beta); auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); - ICHECK(data_md == elt_prim_desc.dst_desc()); - - auto elt = dnnl::eltwise_forward(elt_prim_desc); - net_.push_back(elt); + ICHECK(src_tr.desc() == elt_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Softmax(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - int axis = std::stoi(node.GetAttr>("axis")[0]); + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); + + auto axis = GetNodeAttr(node, "axis"); if (axis < 0) { - axis = shape.size() + axis; + axis = src_tr.dims().size() + axis; } - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); auto softmax_desc = - dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis); + dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, src_tr.desc(), axis); auto softmax_prim_desc = dnnl::softmax_forward::primitive_desc(softmax_desc, engine_); - ICHECK(data_md == softmax_prim_desc.dst_desc()); - - auto softmax = dnnl::softmax_forward(softmax_prim_desc); - net_.push_back(softmax); + ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::softmax_forward(softmax_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + ICHECK_EQ(node.GetInputs().size(), 2U); // Memory and compute description. 
- std::vector data_dims; - std::vector data_mds; - std::vector data_memories; + auto lhs_tr = GetInput(nid, 0); + auto rhs_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); - ICHECK_EQ(node.GetInputs().size(), 2U); - for (auto entry : node.GetInputs()) { - auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - auto dtype = dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); - - data_dims.push_back(data_shape); - data_mds.push_back(data_md); - data_memories.push_back(BindDNNLMemory(entry, data_md)); - } - ICHECK(data_dims[0] == data_dims[1]); - auto out_md = data_mds[0]; - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); + lhs_tr = lhs_tr.Broadcast(dst_tr.dims()); + rhs_tr = rhs_tr.Broadcast(dst_tr.dims()); - auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md); + auto binary_desc = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(), dst_tr.desc()); auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_); - auto binary = dnnl::binary(binary_prim_desc); - net_.push_back(binary); - net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]}, - {DNNL_ARG_SRC_1, data_memories[1]}, - {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::binary(binary_prim_desc), + {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST, dst_tr}}); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stol(val[0]); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stof(val[0]); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return val[0]; + } + + template >::value, int> = 0> + T AttrConvert(std::vector val) { + T res; + for (const auto& el : val) res.push_back(AttrConvert({el})); + return res; + } + + /*! + * \brief Helper to extract node attribute with ability to specify default value and result type. + */ + template + const T GetNodeAttr(const json::JSONGraphNode& node, std::string name, + std::vector def = {}) { + auto attr = node.HasAttr(name) ? node.GetAttr>(name) : def; + return AttrConvert(attr); } - // Read from DNNL memory (+offset) and write to the handle. - inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* src = static_cast(mem.get_data_handle()); - std::copy(src + offset, src + offset + size, static_cast(handle)); + TensorRequisite GetInput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetInputs().size()); + auto data_entry = node.GetInputs()[idx]; + + auto shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + auto dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto eid = node_row_ptr_[data_entry.id_] + data_entry.index_; + auto const_dl_tensor = data_entry_[eid]; + + auto desc = MakePlainDesc(shape, dtype); + + TensorRequisite res; + if (const_dl_tensor) { + ICHECK(const_dl_tensor->data); + ICHECK(const_dl_tensor->strides == nullptr); + auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data); + res = TensorRequisite::AsIs(mem, eid); + } else { + res = TensorRequisite::AsIs(desc, eid); + } + return res; } - // Read from the handle and write to DNNL memory (+offset). 
- inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* dst = static_cast(mem.get_data_handle()); - std::copy(reinterpret_cast(handle), reinterpret_cast(handle) + size, - dst + offset); + TensorRequisite GetOutput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetNumOutput()); + auto shape = node.GetOpShape()[idx]; + auto dtype = node.GetOpDataType()[idx]; + auto eid = node_row_ptr_[nid] + static_cast(idx); + + ICHECK(data_entry_[eid] == nullptr); + auto desc = MakePlainDesc(shape, dtype); + + return TensorRequisite::AsIs(desc, eid).Backward(); } - // Generate DNNL memory description and infer the data layout by the given shape. - inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) { - dnnl::memory::desc data_md; - switch (shape.size()) { - case 2: - data_md = dnnl::memory::desc({shape, dtype, tag::ab}); - break; - case 3: - data_md = dnnl::memory::desc({shape, dtype, tag::abc}); - break; - case 4: - data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); - break; - case 5: - data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); - break; - default: - LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); - break; + /*! \brief Helper function to register primitive into execution queue */ + void Submit(const dnnl::primitive& prim, + const std::unordered_map& tr_args) { + // Register all provided TR arguments + std::unordered_map prim_arg_id; + TensorRegistry::ActionQue post_prim_actions; + for (const auto& kvp : tr_args) { + const auto& key = kvp.first; + const auto& tr = kvp.second; + + if (!tr.defined()) continue; // empty arg is admitted. Just skip it + auto arg_id = tensor_registry_.Register(tr, tr.IsReversed() ? &post_prim_actions : &net_); + prim_arg_id[key] = arg_id; } - return data_md; + + // Register main primitive + net_.push_back({prim, prim_arg_id}); + + // Register post actions + net_.insert(net_.end(), post_prim_actions.begin(), post_prim_actions.end()); } + uint32_t GenUniqueEid() { return next_unique_eid_offset_++; } + /* The dnnl engine. */ dnnl::engine engine_; /* The dnnl stream. */ dnnl::stream stream_; /* The network layers that are represented in dnnl primitives. */ - std::vector net_; - /* The memory that is consumed by arguments. */ - std::vector> net_args_; - /* The entry ID to its corresponding output memory. */ - std::unordered_map> entry_out_mem_; + TensorRegistry::ActionQue net_; + /* Storage for all memory objects */ + TensorRegistry tensor_registry_; + /* Generator of new unique eid which doesn't match with existing data entry */ + uint32_t next_unique_eid_offset_; + /* Map of Run arg idx to corresponding eid */ + std::vector run_arg_eid_; }; runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h new file mode 100644 index 0000000000000..d02ceff5de823 --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h @@ -0,0 +1,720 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl_tensor_requisite.cc + * \brief Helper TR wrapper to simplify tensors processing + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace utils; + +/*! + * \brief Helper object to simplify tensor transformation description. + * + * Allow to specify original source tensor and future actions which should be applied to it. + * Can be treated as sequence of reordering or reinterpretation of original source tensor. + * Finally TR can be solved as proper interpretation of source memory buffer, or sequence of + * dnnl::reorder operators which will provide desired data. + * + * \note Empty TR object allow any manipulation. Empty TR will be returned. + * + * \sa TensorRegistry + * + * Example: + * \code + * dnnl::memory src_mem = ...; // 5D tensor, shape {5, 2, 128, 128, 8} + * + * // Construct TR + * auto tr = TensorRequisite.AsIs(src_mem, eid); // 5D + * + * // describe sequence of layout transformation + * tr = tr.TreatAs("ABCD8b"); // 4D + * tr = tr.Permute({0, 2, 3, 1}); // Permute axes NCHW -> NHWC + * tr = tr.Crop({1, 128, 128, 16}, {0, 0, 0}); // extract first batch element + * tr = tr.Squeeze(); // 1D + * + * // register TR + * TensorRegistry t_reg; + * auto t_id = t_reg.register(tr); + * + * // Get final dnnl::memory object + * auto solver = t_reg.MakeSolver(ext_tensor_provider); + * auto mem = solver(t_id); + * \endcode + * + */ +class TensorRequisite { + public: + using Tid = uint32_t; + static constexpr Tid kUndefinedTid = std::numeric_limits::max() - 1; + + /*! \brief Empty constructor */ + TensorRequisite() {} + + /*! \brief Construct TR on top of existing memory object */ + static TensorRequisite AsIs(const dnnl::memory& mem, Tid id = kUndefinedTid) { + auto res = AsIs(mem.get_desc(), id); + if (mem.get_data_handle() != nullptr) res.mem_ = mem; + return res; + } + + /*! \brief Construct TR on top of existing memory descriptor object */ + static TensorRequisite AsIs(const dnnl::memory::desc& desc, Tid id = kUndefinedTid) { + return {desc, {}, false, {}, id, false}; + } + + /*! \brief return logical shape of tensor */ + dnnl::memory::dims dims() const { return t_desc_.dims(); } + + /*! \brief return data type of tensor */ + dnnl::memory::data_type data_type() const { return t_desc_.data_type(); } + + /*! \brief return tensor desc */ + dnnl::memory::desc desc() const { return t_desc_; } + + /*! 
\brief Make TR with backward dataflow */ + TensorRequisite Backward() const { + if (!defined()) return *this; + ICHECK(orig_ == nullptr); + return {t_desc_, orig_, reinterpret_, mem_, eid_, true}; + } + + /*! \brief Produce TR with permuted axes */ + TensorRequisite Permute(const std::vector& permutation) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.permute_axes(permutation); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with reinterpret data of original tr */ + TensorRequisite Reshape(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(shape); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with broadcasted values */ + TensorRequisite Broadcast(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + ICHECK(!reverse_data_flow_); + + auto orig = std::make_shared(*this); + + // numpy like broadcast + auto extended_dims = t_desc_.dims(); + auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(), 1); + extended_dims.insert(extended_dims.begin(), one_filled.begin(), one_filled.end()); + auto desc = t_desc_.reshape(extended_dims); + for (size_t i = 0; i < extended_dims.size(); i++) { + if (extended_dims[i] == shape[i]) continue; + ICHECK(extended_dims[i] == 1); + ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); + + desc.data.dims[i] = shape[i]; + desc.data.padded_dims[i] = shape[i]; + desc.data.format_desc.blocking.strides[i] = 0; + } + + // reinterpret memory buffer with new strides + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with sub memory view (ROI) */ + TensorRequisite Crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const { + if (!defined()) return *this; // nothing for empty TR + + ICHECK_EQ(shape.size(), t_desc_.dims().size()); + ICHECK_EQ(offset.size(), t_desc_.dims().size()); + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.submemory_desc(shape, offset, /*allow_empty=*/true); + + // Originally DNNL implementation is very limited. Let's slightly enhance it. + if (!desc && t_desc_.data.format_kind == dnnl_blocked) { + bool offset_is_zero = + std::all_of(offset.begin(), offset.end(), [](auto el) { return el == 0; }); + + dnnl::memory::dims block_sizes(t_desc_.dims().size(), 1); + for (int i = 0; i < t_desc_.data.format_desc.blocking.inner_nblks; i++) + block_sizes[t_desc_.data.format_desc.blocking.inner_idxs[i]] *= + t_desc_.data.format_desc.blocking.inner_blks[i]; + + bool shape_reduction_less_than_block = true; + for (int i = 0; i < t_desc_.data.ndims; i++) { + shape_reduction_less_than_block &= t_desc_.data.dims[i] - shape[i] < block_sizes[i]; + } + + // This is auto padded case. Just update dims value. + if (offset_is_zero && shape_reduction_less_than_block) { + desc = t_desc_; + std::copy(shape.begin(), shape.end(), desc.data.dims); + } + } + + ICHECK(desc); + + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! 
\brief Produce TR with squeeze shape */ + TensorRequisite Squeeze(const dnnl::memory::dims& dims_to_squeeze = {}) const { + if (!defined()) return *this; // nothing for empty TR + + dnnl::memory::dims squeezed_dims; + if (dims_to_squeeze.empty()) { + for (auto d : t_desc_.dims()) + if (d != 1) squeezed_dims.push_back(d); + } else { + for (size_t i = 0; i < t_desc_.dims().size(); i++) + if (std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) == dims_to_squeeze.end()) + squeezed_dims.push_back(t_desc_.dims()[i]); + } + + if (squeezed_dims.empty()) squeezed_dims = {1}; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(squeezed_dims); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with specified layout descriptor */ + TensorRequisite RequestLayout(dnnl::memory::desc desc) const { + if (!defined()) return *this; // nothing for empty TR + + // If it's the same desc just return self + if (desc == t_desc_) return *this; + + ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with " + "presented shape"; + + auto orig = std::make_shared(*this); + return {desc, orig, false, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Define which logical dims ordering is default for particular layout string. */ + static std::string DefaultLogicLayoutFor(const std::string& layout) { + // Rank is all non digit marked dims + auto it = layout.begin(); + while (it != layout.end() && !std::isdigit(*it)) it++; + int rank = std::distance(layout.begin(), it); + + static const std::vector sparse_dims = {"W", "HW", "DHW"}; + if (layout.find("N") != std::string::npos) return "NC" + sparse_dims[rank - 3]; + if (layout.find("G") != std::string::npos) return "GOI" + sparse_dims[rank - 4]; + if (layout.find("O") != std::string::npos) return "OI" + sparse_dims[rank - 3]; + + LOG(FATAL) << "Unknown layout " << layout << "There is no default scheme to handle it"; + return {}; + } + + /*! + * \brief Treat TR shape as described in layout string. + * + * Blocked dimensions will be concatenated and put into proper shape position corresponding to . + * resulting_layout_logic argument. If desired logic layout was not provided it will be deduced + * automatically based on some internal heuristics. + * + * Limitation 1. Blocking dims should be dense. Dims marked with digits use natural strides. + * Limitation 2. Blocking dims are innermost. Dims marked like 8c, 4o goes after regular + * dimensions. NC8cHW4h4cD is not valid tensor in terms of DNNL. And cannot be + * achieved with memory reinterpretation, so data copy is required. Proper layout + * looks like NCHWD_8c4h4c, first part is outer dims, second digits marked part is + * innermost. + */ + TensorRequisite TreatAs(const std::string& layout, std::string desired_logic_layout = "") const { + if (desired_logic_layout.empty()) desired_logic_layout = DefaultLogicLayoutFor(layout); + + const auto origin_dims = dims(); + + // split layout string to tokens {size, tag} like {16, 'C'}, {4, 'O'} + std::vector> layout_tokens; + for (auto it = layout.begin(); it != layout.end();) { + auto start = it; + while (std::isdigit(*it)) it++; + int blk_size = start == it ? 
-1 : std::stoi(std::string{start, it}); + layout_tokens.push_back({blk_size, std::toupper(*it)}); + it++; + } + + // check applicability of layout + auto it = layout_tokens.begin(); + while (it != layout_tokens.end() && it->first == -1) it++; + int rank = std::distance(layout_tokens.begin(), it); + while (it != layout_tokens.end()) { + ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. " + << "But received layout is " << layout; + it++; + } + + ICHECK_EQ(layout_tokens.size(), origin_dims.size()); + ICHECK_EQ(rank, desired_logic_layout.size()) << layout; + + std::vector> outermost_tokens(layout_tokens.begin(), + layout_tokens.begin() + rank); + std::vector> innermost_tokens(layout_tokens.begin() + rank, + layout_tokens.end()); + // define dim resulting dim positions + std::map dim_position_by_tag; + for (size_t i = 0; i < desired_logic_layout.size(); i++) + dim_position_by_tag[std::toupper(desired_logic_layout[i])] = i; + + // Construct resulting desc by modifying original one + dnnl::memory::desc res_desc = t_desc_; + + memset(&res_desc.data.format_desc.blocking, 0, sizeof(res_desc.data.format_desc.blocking)); + std::fill(res_desc.data.dims, res_desc.data.dims + DNNL_MAX_NDIMS, 0); + std::fill(res_desc.data.padded_dims, res_desc.data.padded_dims + DNNL_MAX_NDIMS, 0); + + res_desc.data.ndims = rank; + res_desc.data.format_desc.blocking.inner_nblks = innermost_tokens.size(); + + auto res_dims = res_desc.data.dims; + auto res_strides = res_desc.data.format_desc.blocking.strides; + auto res_inner_blks = res_desc.data.format_desc.blocking.inner_blks; + auto res_inner_idxs = res_desc.data.format_desc.blocking.inner_idxs; + + std::fill(res_dims, res_dims + rank, 1); + + int orig_dim_idx = 0; + for (const auto& p : outermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + + auto result_dim_position = dim_position_by_tag[tag]; + res_dims[result_dim_position] *= dim_size; + res_strides[result_dim_position] = t_desc_.data.format_desc.blocking.strides[orig_dim_idx]; + orig_dim_idx++; + } + for (const auto& p : innermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + auto result_dim_position = dim_position_by_tag[tag]; + ICHECK_EQ(p.first, dim_size) + << "Blocking layout is not applicable to tensor with shape: " << origin_dims + << ". Requested layout is " << layout; + + res_dims[result_dim_position] *= dim_size; + *res_inner_blks++ = dim_size; + *res_inner_idxs++ = result_dim_position; + orig_dim_idx++; + } + + // Assume tensor is dense. There is no additional padding. + std::copy(res_desc.data.dims, res_desc.data.dims + rank, res_desc.data.padded_dims); + + if (t_desc_ == res_desc) return *this; + + auto orig = std::make_shared(*this); + return {res_desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! + * \brief Produce TR with unspecified layout. + * + * Cannot be registered in TensorRegistry. Only for querying DNNL for preferred layouts. + */ + TensorRequisite LayoutAny() const { + auto orig = std::make_shared(*this); + // Recreate tensor desc with layout 'any' + dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(), dnnl::memory::format_tag::any}; + return {any_desc, orig, false, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Check is TR is constant. */ + bool IsConstant() const { + if (orig_) return orig_->IsConstant(); + return mem_.operator bool(); + } + + /*! \brief Check is tensor is scalar. 
*/ + bool IsScalar() const { return t_desc_.dims().size() == 1 && t_desc_.dims()[0] == 1; } + + /*! \brief Return const data memory if available. */ + dnnl::memory GetConstData() const { + if (mem_) return mem_; + if (!orig_) return {}; + + if (auto orig_const_data = orig_->GetConstData()) { + if (reinterpret_) { + return {t_desc_, orig_const_data.get_engine(), orig_const_data.get_data_handle()}; + } else { + auto eng = orig_const_data.get_engine(); + auto res = dnnl::memory{t_desc_, eng}; + dnnl::reorder(orig_const_data, res).execute(dnnl::stream(eng), orig_const_data, res); + return res; + } + } + return {}; + } + + /*! + * \brief Return const data memory in form of vector. + * + * Same as GetConstData but use std::vector instead of dnnl::memory. Works only for 1D tensor + * and scalar TRs. Useful for specification of 1D DNNL attributes like zero_point or + * per_channel_scale + */ + template + std::vector GetConstDataLikeVec() const { + auto const_data = GetConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::DnnlDType()); + ICHECK(desc.dims().size() == 1); + + auto size = desc.get_size() / sizeof(T); + auto ptr = static_cast(const_data.get_data_handle()); + + return std::vector(ptr, ptr + size); + } + + /*! \brief Get value of constant scalar tensor if possible. */ + template + T GetConstScalarData() const { + ICHECK(IsConstant()); + ICHECK(IsScalar()); + auto const_data = GetConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::DnnlDType()); + + auto ptr = static_cast(const_data.get_data_handle()); + return *ptr; + } + + /*! \brief Check if tensor is not empty. */ + bool defined() const { return !t_desc_.is_zero(); } + + /*! \brief Same as defined */ + operator bool() const { return defined(); } + + /*! + * \brief Check if tensor represent a reversed data flow. + * Useful for describing output processing + */ + bool IsReversed() const { return reverse_data_flow_; } + + private: + TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr& orig, + bool reinterpret, const dnnl::memory& const_mem, uint32_t eid, + bool reverse_data_flow) + : t_desc_(t_desc), + orig_(orig), + reinterpret_(reinterpret), + mem_(const_mem), + eid_(eid), + reverse_data_flow_(reverse_data_flow) { + if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid); + if (eid_ != kUndefinedTid) ICHECK(!orig_); + } + + /* Descriptor of particular tensor */ + dnnl::memory::desc t_desc_ = {}; + /* Parent TR object which is referred from this TR */ + std::shared_ptr orig_ = {}; + /* Flag to specify which action should be done with orig TR, reordering or reinterpretation */ + bool reinterpret_ = false; + /* Const memory object if available */ + dnnl::memory mem_ = {}; + /* Entry ID of tensor if available */ + uint32_t eid_ = kUndefinedTid; + + /* + * Flag to describe reverse data flow case + * All operation on queue will be executed in reverse order. Actual for dst tensor description + */ + bool reverse_data_flow_ = false; + + friend class TensorRegistry; +}; + +/*! + * \brief The registry of tensors. Implement matching of provided TRs and real memory buffers. + * + * Registration of TR performed by calling method Register(), which will return ArgId object. + * ArgId can be mapped to real memory via memory solver created by method MakeSolver(). + */ +class TensorRegistry { + private: + enum ArgReqFlag { + CONST, /// < Constant tensor. ExecutionCTX independent + TMP_STORAGE, /// < Intermediate tensors. Stored inside TensorRegistry. 
Inaccessible outside + EXT_EID, /// < External data. Input or Output. + }; + + public: + struct ArgId { + TensorRegistry::ArgReqFlag flag_; + uint32_t idx_; + }; + + using Action = std::tuple>; + using ActionQue = std::vector; + using DLTensorProvider = std::function; + using MemSolver = std::function; + + TensorRegistry() = default; + TensorRegistry(const dnnl::engine& eng, const std::set& ext_io_eid) + : tmp_mem_collection_(1), ext_io_eid_(ext_io_eid), eng_(eng), stream_(eng) {} + + /*! + * \brief Register TR to registry + * + * Resolution of TR may lead to introduction of intermediate memory buffers and additional + * transformation actions which should be performed before or after usage of corresponding memory + * buffer. Additional actions will be append to provided actions queue. Corresponding to + * tr.IsReversed() value actions should be executed before or after usage of resulting ArgId. + * + * \param tr tensor requisite sequence to register + * \param action resulting action queue. If TR resolution is required execution of some + * transformation actions they will be put here + * \return associated ArgId. Should be used as argument for MemSolver. + */ + ArgId Register(const TensorRequisite& tr, ActionQue* action) { + // 1) Constant tensor. Direct reference + if (auto const_data = tr.GetConstData()) { + auto idx = const_mem_collection_.size(); + const_mem_collection_.push_back(const_data); + return MakeArgReq(ArgReqFlag::CONST, static_cast(idx)); + } + + // 2) EID mapped tensor. Direct reference + if (tr.eid_ != TensorRequisite::kUndefinedTid) { + if (ext_io_eid_.count(tr.eid_) == 0) { // Not IO tensor, means it's intermediate + if (eid2idx_tmp_.count(tr.eid_)) { + auto idx = eid2idx_tmp_.at(tr.eid_); + return MakeArgReq(ArgReqFlag::TMP_STORAGE, idx); + } else { + // register himself + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(tr.t_desc_); + eid2idx_tmp_[tr.eid_] = idx; + return MakeArgReq(ArgReqFlag::TMP_STORAGE, static_cast(idx)); + } + } else { + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({tr.eid_, tr.t_desc_}); + return MakeArgReq(ArgReqFlag::EXT_EID, static_cast(idx)); + } + } + + // 3) Tensors with transform actions + if (tr.orig_) { + // recursive register of orig TR + auto orig_arg_req = Register(*tr.orig_, action); + if (tr.reinterpret_) { + return RegisterReinterpret(orig_arg_req, tr.t_desc_); + } else { + return RegisterReorder(orig_arg_req, tr.t_desc_, tr.reverse_data_flow_, action); + } + } + + // 4) Scratchpad + ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid); + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(tr.t_desc_); + tmp_mem_mapping_[idx] = 0; // zero position tmp mem object is reserved for scratchpads + + auto scratchpad_size = tr.t_desc_.get_size(); + auto glob_scratchpad_size = tmp_mem_collection_[0].get_size(); + if (scratchpad_size > glob_scratchpad_size) { + tmp_mem_collection_[0] = + dnnl::memory::desc({static_cast(scratchpad_size)}, + dnnl::memory::data_type::u8, dnnl::memory::format_tag::a); + } + return MakeArgReq(TMP_STORAGE, static_cast(idx)); + } + + /*! + * \brief Construct memory solver for all registered TRs. 
+ * \param ext_provider callback to resolve external IO buffers + * \return memory solver object to match ArgId to dnnl::memory objects + */ + MemSolver MakeSolver(const DLTensorProvider& ext_provider) const { + return MemSolverImpl(eng_, ext_provider, const_mem_collection_, ext_mem_collection_, + tmp_mem_collection_, tmp_mem_mapping_); + } + + private: + ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) { + switch (src_ar.flag_) { + case TMP_STORAGE: { + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(desc); + tmp_mem_mapping_[idx] = src_ar.idx_; + return MakeArgReq(TMP_STORAGE, idx); + } + case EXT_EID: { + auto ext_req = ext_mem_collection_[src_ar.idx_]; + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({ext_req.first, desc}); + return MakeArgReq(EXT_EID, idx); + } + default: + LOG(FATAL) << "Unknown case"; + } + return {}; + } + + ArgId RegisterReorder(ArgId src_ar, const dnnl::memory::desc& desc, bool reverse_data_flow, + ActionQue* action) { + ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID); + + auto src_desc = src_ar.flag_ == TMP_STORAGE ? tmp_mem_collection_[src_ar.idx_] + : ext_mem_collection_[src_ar.idx_].second; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(desc); + auto dst_ar = MakeArgReq(TMP_STORAGE, idx); + + // reorder action submit + if (reverse_data_flow) { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, desc, eng_, src_desc); + action->insert(action->begin(), + {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}}); + } else { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, src_desc, eng_, desc); + action->push_back( + {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}}); + } + return dst_ar; + } + /*! \brief Implementation of memory solver */ + class MemSolverImpl { + public: + MemSolverImpl(const dnnl::engine& eng, const DLTensorProvider& ext_data_provider, + const std::vector& const_mems, + const std::vector>& ext_mems, + const std::vector& tmp_mem_descs, + const std::map& tmp_mem_mapping) + : eng_(eng), + ext_data_provider_(ext_data_provider), + const_mems_(const_mems), + ext_mems_(ext_mems) { + // Construct temp memory objects on the fly. While we have no scratchpads + // support on VM/GraphExecutor level. + tmp_mems_.resize(tmp_mem_descs.size()); + for (size_t i = 0; i < tmp_mem_descs.size(); i++) { + auto found = tmp_mem_mapping.find(i); + + if (found != tmp_mem_mapping.end()) { + auto reuse_hdl = tmp_mems_[found->second].get_data_handle(); + tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_, reuse_hdl); + } else { + tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_); + } + } + } + + /*! 
\brief Find memory object associated with provided ArgId */ + dnnl::memory operator()(const ArgId& ar) const { + switch (ar.flag_) { + case CONST: + return const_mems_.at(ar.idx_); + case TMP_STORAGE: + return tmp_mems_.at(ar.idx_); + case EXT_EID: { + auto eid_and_desc = ext_mems_.at(ar.idx_); + auto eid = eid_and_desc.first; + auto desc = eid_and_desc.second; + + auto ext_dl_tensor = ext_data_provider_(eid); + ICHECK(ext_dl_tensor->data); + return dnnl::memory{desc, eng_, ext_dl_tensor->data}; + } + } + return {}; + } + + private: + const dnnl::engine& eng_; + const DLTensorProvider& ext_data_provider_; + const std::vector& const_mems_; + const std::vector>& ext_mems_; + std::vector tmp_mems_; + }; + + ArgId MakeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; } + + /* Collection of const memory objects. */ + std::vector const_mem_collection_; + + /* Collection of intermediate memory descriptors. Zero position is reserved for scratchpads. */ + std::vector tmp_mem_collection_; + + /* Mapping of some temp buffer on previously registered. */ + std::map tmp_mem_mapping_; + + /* Collection of external_intermediate memory objects. + * first - eid of external buffer to ask + * second - t_desc describes how to treat external buffer */ + std::vector> ext_mem_collection_; + + /* Map of eid to index of temp buffer in tmp_mem_collection_ */ + std::unordered_map eid2idx_tmp_; + + /* List of external eid */ + std::set ext_io_eid_; + + /* Engine of all tensors existing in this registry */ + dnnl::engine eng_; + + /* Execution stream use to reorder const data */ + dnnl::stream stream_; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc b/src/runtime/contrib/dnnl/dnnl_utils.cc index 7e79f1c939cfe..23992209f2ad5 100644 --- a/src/runtime/contrib/dnnl/dnnl_utils.cc +++ b/src/runtime/contrib/dnnl/dnnl_utils.cc @@ -23,11 +23,14 @@ #include "dnnl_utils.h" +#include "tvm/runtime/logging.h" + namespace tvm { namespace runtime { namespace contrib { -using dt = dnnl::memory::data_type; -dt dtype_dl2dnnl(DLDataType dltype) { + +dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) { + using dt = dnnl::memory::data_type; dt dnnl_type = dt::undef; if (dltype.code == DataType::TypeCode::kFloat) { if (dltype.bits == 16) { @@ -51,6 +54,23 @@ dt dtype_dl2dnnl(DLDataType dltype) { } return dnnl_type; } + +dnnl::memory::dims shape_dl2dnnl(const std::vector& shape) { + if (shape.empty()) return {1}; // DNNL scalar representation is 1D tensor + return shape; +} + +dnnl::memory::desc MakePlainDesc(const std::vector& shape, DLDataType dltype) { + auto dnnl_shape = shape_dl2dnnl(shape); + auto dnnl_dtype = dtype_dl2dnnl(dltype); + + auto dnnl_plain_strides = dnnl::memory::dims(dnnl_shape.size(), 1); + for (int i = dnnl_shape.size() - 2; i >= 0; i--) + dnnl_plain_strides[i] = dnnl_plain_strides[i + 1] * dnnl_shape[i + 1]; + + return {dnnl_shape, dnnl_dtype, dnnl_plain_strides}; +} + } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h index 4fb236f96f8b1..a598b6704450f 100644 --- a/src/runtime/contrib/dnnl/dnnl_utils.h +++ b/src/runtime/contrib/dnnl/dnnl_utils.h @@ -18,16 +18,23 @@ */ /*! - * \file src/runtime/contrib/dnnl/dnnl_utils.h - * \brief utils for DNNL. 
+ * \file src/runtime/contrib/dnnl/dnnl_utils.cc + * \brief Some DNNL specific utility functions */ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ -#include +#include +#include +#include +#include -#include "dnnl.hpp" +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "tvm/runtime/data_type.h" namespace tvm { namespace runtime { @@ -40,7 +47,90 @@ namespace contrib { */ dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype); +/*! + * \brief Converter TVM shape to DNNL dims + * \param shape tvm shape + * \return dims in terms of dnnl + */ +dnnl::memory::dims shape_dl2dnnl(const std::vector& shape); + +/*! + * \brief Construct plain tensor descriptor + * \param shape provided shape + * \param dltype provided data type + * \return resulting plain tensor desc + */ +dnnl::memory::desc MakePlainDesc(const std::vector& shape, DLDataType dltype); + +namespace utils { + +/*! \brief Pretty printer util for shape */ +inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::dims& dims) { + o << "["; + auto d = dims.begin(); + if (d != dims.end()) o << *d++; + while (d != dims.end()) o << "," << *d++; + o << "]"; + return o; +} + +/*! \brief Pretty printer util for data type */ +inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::data_type& type) { + std::string name = "undef"; + switch (type) { + case dnnl::memory::data_type::undef: + name = "undef"; + break; + case dnnl::memory::data_type::f32: + name = "fp32"; + break; + case dnnl::memory::data_type::f16: + name = "fp16"; + break; + case dnnl::memory::data_type::bf16: + name = "bf16"; + break; + case dnnl::memory::data_type::s32: + name = "i32"; + break; + case dnnl::memory::data_type::s8: + name = "i8"; + break; + case dnnl::memory::data_type::u8: + name = "u8"; + break; + } + o << name; + return o; +} + +/*! 
\brief Converter data type template arg to runtime object */ +template +inline dnnl::memory::data_type DnnlDType(); + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::s32; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::f32; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::u8; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::s8; +} + +} // namespace utils } // namespace contrib } // namespace runtime } // namespace tvm + #endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ From 4f5ab57d348e97b707d0707f9272cebe03a79777 Mon Sep 17 00:00:00 2001 From: ChunPing Chung Date: Fri, 3 Jun 2022 00:28:38 +0800 Subject: [PATCH 020/181] [Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507) * [Frontend][ONNX] Fix softmax converter when input shape is dynamic * [Frontend][ONNX] mark dynamic softmax tests as xfailed with cuda --- python/tvm/relay/frontend/onnx.py | 2 ++ tests/python/frontend/onnx/test_forward.py | 37 ++++++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 30e8188a8312c..997aa6240e9e8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2420,6 +2420,8 @@ def _impl_v1(cls, inputs, attr, params): axis += ndim if axis == 0: reshape_shape = [-1] + elif axis == ndim - 1: + return _op.nn.softmax(inputs[0], axis=axis) else: axis_val = [in_shape[i] for i in range(axis)] reshape_shape = [np.prod(axis_val)] + [-1] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dbc5147e20300..c4cd93aa7d9b0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1589,26 +1589,45 @@ def test_upsample3d_trilinear(target, dev): tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) +# TODO: Fix softmax with dynamic input on cuda and enable this test +@tvm.testing.known_failing_targets("cuda") @tvm.testing.parametrize_targets def test_softmax(target, dev): - def verify_softmax(inshape, axis): + def verify_softmax(inshape, axis, opset=None, dynamic=False): opname = "Softmax" - indata = np.random.uniform(size=inshape).astype(np.float32) outshape = inshape - y = helper.make_node(opname, ["in"], ["out"]) + node_list = [] + input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))] + output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))] + input_list = [np.random.uniform(size=inshape).astype(np.float32)] + softmax_inputs = ["in"] + + if dynamic: + input_node_list.append( + helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)]) + ) + input_list.append(np.asarray(inshape)) + reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"]) + softmax_inputs[0] = "dynamic_in" + node_list += [reshape_node] + + y = helper.make_node(opname, softmax_inputs, ["out"]) if axis is not None: axis_attr = helper.make_attribute("axis", axis) y.attribute.append(axis_attr) + node_list.append(y) graph = helper.make_graph( - [y], + node_list, opname + "_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], + inputs=input_node_list, + 
outputs=output_node_list, ) model = helper.make_model(graph, producer_name=opname + "_test") - verify_with_ort_with_inputs(model, [indata], target=target, dev=dev) + verify_with_ort_with_inputs( + model, input_list, use_vm=True, opset=opset, target=target, dev=dev + ) verify_softmax((1, 10), None) verify_softmax((1, 10), 1) @@ -1616,6 +1635,10 @@ def verify_softmax(inshape, axis): verify_softmax((1, 2, 3, 10), 2) verify_softmax((1, 2, 3, 4, 10), 3) verify_softmax((1, 2, 3, 4, 10), 4) + verify_softmax((1, 10), -1, dynamic=True) + verify_softmax((1, 2, 3, 10), -1, dynamic=True) + verify_softmax((1, 10), -1, opset=8, dynamic=True) + verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True) @tvm.testing.parametrize_targets From 480fa744eb66a2c6013d43ee46778d02b905ca19 Mon Sep 17 00:00:00 2001 From: Jocelyn S Date: Thu, 2 Jun 2022 13:15:04 -0400 Subject: [PATCH 021/181] [Onnx] Round operator (#11446) * banker round op added based off tutorial * black'd onnx.py file * retriggering CI with empty commit due to autoscheduler test failure * removed youtube link in comments * retriggering CI due to test failure that passed locally --- python/tvm/relay/frontend/onnx.py | 21 +++++++++++++++++++-- tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 997aa6240e9e8..abfa5629d5534 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -5061,10 +5061,27 @@ def _impl_v1(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple(result), len(result)) +class Round(OnnxOpConverter): + """Operator converter for round op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + # Onnx round uses Banker's rounding which rounds .5 to the nearest even integer + + x = inputs[0] + half = _expr.const(0.5, dtype="float32") + one = _expr.const(1, dtype="float32") + two = _expr.const(2, dtype="float32") + + rounded = _op.ceil(x - half) + bankers_mask = one - (_op.ceil(x + half) - _op.floor(x + half)) + non_even = _op.abs(_op.mod(rounded, two)) + return rounded + (bankers_mask * non_even) + + # compatible operators that do NOT require any conversion. 
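# To see why the Round converter above reproduces ONNX round-half-to-even, the
# same arithmetic can be checked with plain NumPy. This is only an illustrative
# sketch (the helper name `onnx_style_round` is made up and is not part of the
# converter); np.round itself rounds halves to the nearest even integer, so it
# serves as the reference.
import numpy as np

def onnx_style_round(x):
    rounded = np.ceil(x - 0.5)
    # 1.0 exactly at halfway values (x + 0.5 is an integer), 0.0 otherwise
    bankers_mask = 1.0 - (np.ceil(x + 0.5) - np.floor(x + 0.5))
    # 1.0 when the tentative result is odd; only matters at halfway points,
    # where the mask gates the correction
    non_even = np.abs(np.mod(rounded, 2.0))
    return rounded + bankers_mask * non_even

x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.3, 2.5, 2.7], dtype="float32")
assert np.array_equal(onnx_style_round(x), np.round(x))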
_identity_list = [] - # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted @@ -5109,7 +5126,7 @@ def _get_convert_map(opset): "Reciprocal": Reciprocal.get_converter(opset), "Floor": Renamer("floor"), "Ceil": Renamer("ceil"), - "Round": Renamer("round"), + "Round": Round.get_converter(opset), "IsInf": IsInf.get_converter(opset), "IsNaN": Renamer("isnan"), "Sqrt": Renamer("sqrt"), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c4cd93aa7d9b0..ebaad9b4cb136 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5183,7 +5183,6 @@ def verify_eyelike(indata): "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", "test_rnn_seq_length", - "test_round", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_simple_rnn_batchwise", From 84eb78cbc4663d6f25ee5a7ead6a930eba02776b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 2 Jun 2022 10:47:29 -0700 Subject: [PATCH 022/181] [MetaSchedule] No explicit for spatial PrimFunc (#11534) --- .../parallel_vectorize_unroll.cc | 7 +- src/tir/schedule/analysis.h | 7 + src/tir/schedule/analysis/analysis.cc | 19 ++ ...schedule_rule_parallel_vectorize_unroll.py | 179 ++++++++++++++++++ 4 files changed, 211 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index c0e57a6d037a5..19758996e6080 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -26,6 +26,11 @@ bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) { return block_sref->parent == nullptr; } +bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) { + return IsSpatialPrimFunc( + GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); +} + } // namespace tir } // namespace tvm @@ -60,7 +65,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); } // Unroll - if (!unroll_max_steps.empty()) { + if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; Array probs(n, FloatImm(DataType::Float(64), prob)); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 0574cfefadb6f..5adc4f8f1b30a 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -625,6 +625,13 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); */ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Checks if all the blocks in the PrimFunc is spatial + * \param func The PrimFunc to be checked + * \return A boolean indicating whether all the blocks in the PrimFunc is spatial + */ +bool IsSpatialPrimFunc(const PrimFunc& func); + /*! * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. * \param self The schedule state. 
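Here a PrimFunc counts as spatial when every block iteration variable in its body is data-parallel, i.e. there is no reduction iter var anywhere; presumably automatic unrolling brings little benefit to such purely elementwise/layout workloads, so the rule above skips the unroll sampling for them while keeping parallelization and vectorization. A minimal TVMScript sketch of the two cases (illustrative only; the function names are made up and not part of the patch):

from tvm.script import tir as T

@T.prim_func
def spatial_func(A: T.Buffer[(64, 64), "float32"], B: T.Buffer[(64, 64), "float32"]) -> None:
    # Only "S" (data-parallel) iter vars: IsSpatialPrimFunc would return true,
    # so ParallelizeVectorizeUnroll no longer samples an unroll depth here.
    for i, j in T.grid(64, 64):
        with T.block("compute"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] + T.float32(1)

@T.prim_func
def non_spatial_func(A: T.Buffer[(64, 64), "float32"], B: T.Buffer[(64,), "float32"]) -> None:
    # The "R" (reduction) iter var makes this PrimFunc non-spatial, so the
    # unroll sampling still applies as before.
    for i, k in T.grid(64, 64):
        with T.block("row_sum"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]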
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 83ef6adae3b23..0f84dfef1135f 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1957,6 +1957,25 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref return total_unused_block_vars >= 1; } +bool IsSpatialPrimFunc(const PrimFunc& func) { + bool result = true; + PreOrderVisit(func->body, [&result](const ObjectRef& obj) { + if (result == false) { + return false; + } + if (const auto* block = obj.as()) { + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + result = false; + return false; + } + } + } + return true; + }); + return result; +} + std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) { Array loops = tir::GetLoops(block_sref); diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index e57799f604b8a..85aa80eb3c82b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm +from tvm import meta_schedule as ms from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll from tvm.meta_schedule.testing.space_generation import check_trace @@ -61,6 +62,164 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# from tvm.script import tir as T +@tvm.script.ir_module +class PureSpatial: + @T.prim_func + def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32") + T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32") + T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32") + T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32") + T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32") + T_reshape = T.alloc_buffer([8112, 80], dtype="float32") + T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32") + T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32") + T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32") + T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32") + T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32") + T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32") + T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32") + T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32") + T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32") + T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32") + T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32") 
+ T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32") + T_concat = T.alloc_buffer([10647, 80], dtype="float32") + T_transpose = T.alloc_buffer([80, 10647], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1): + with T.block("T_strided_slice_with_axes"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) + T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] + for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1): + with T.block("T_sigmoid"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): + with T.block("T_strided_slice_with_axes_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) + T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] + for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): + with T.block("T_sigmoid_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4]) + T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] + for i0, i1 in T.grid(8112, 80): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] + for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1): + with T.block("T_strided_slice_with_axes_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) + T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] + for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1): + with T.block("T_sigmoid_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): + with 
T.block("T_strided_slice_with_axes_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) + T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] + for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): + with T.block("T_sigmoid_3"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80): + with T.block("T_multiply_1"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4]) + T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] + for i0, i1 in T.grid(2028, 80): + with T.block("T_reshape_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) + T.writes(T_reshape_1[ax0, ax1]) + T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] + for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1): + with T.block("T_strided_slice_with_axes_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]) + T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)] + for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1): + with T.block("T_sigmoid_4"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): + with T.block("T_strided_slice_with_axes_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]) + T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4]) + T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)] + for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): + with T.block("T_sigmoid_5"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]) + T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80): + with T.block("T_multiply_2"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4]) + T_multiply_2[ax0, 
ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] + for i0, i1 in T.grid(507, 80): + with T.block("T_reshape_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]) + T.writes(T_reshape_2[ax0, ax1]) + T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80] + for i0, i1 in T.grid(10647, 80): + with T.block("T_concat"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1]) + T.writes(T_concat[ax0, ax1]) + T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32") + for i0, i1 in T.grid(80, 10647): + with T.block("T_transpose"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_concat[ax1, ax0]) + T.writes(T_transpose[ax0, ax1]) + T_transpose[ax0, ax1] = T_concat[ax1, ax0] + for i0, i1, i2 in T.grid(1, 80, 10647): + with T.block("T_expand_dims"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_transpose[ax1, ax2]) + T.writes(T_expand_dims[ax0, ax1, ax2]) + T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2] + + # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks # fmt: on @@ -101,5 +260,25 @@ def test_parallel_vectorize_unroll(): check_trace(spaces, expected) +def test_parallel_vectorize_unroll_spatial(): + mod = PureSpatial + target = Target("llvm --num-cores=32") + ctx = _create_context( + mod=mod, + target=target, + rule=ms.schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, + max_vectorize_extent=-1, + unroll_max_steps=[1, 2, 4, 8, 16, 32, 64], + unroll_explicit=True, + ), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + trace = spaces[0].trace.simplified(remove_postproc=True) + assert not trace.insts + + if __name__ == "__main__": test_parallel_vectorize_unroll() + test_parallel_vectorize_unroll_spatial() From 3bee5cacd7da5295e42e99e92d1864a97c9ffe80 Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Thu, 2 Jun 2022 11:22:02 -0700 Subject: [PATCH 023/181] [ci][wip] Upload docs with folder structure to S3 (#11528) Keeping the files as-is lets us serve them from S3 + CloudFront Co-authored-by: driazati --- Jenkinsfile | 7 +++++-- jenkins/Test.groovy.j2 | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index b9175f06afdc5..334448a7ae24b 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,7 +45,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-05-31T16:54:56.997402 +// Generated at 2022-06-01T16:34:53.941462 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. 
--> @@ -2875,7 +2875,10 @@ stage('Test') { label: 'Upload artifacts to S3', ) - archiveArtifacts(artifacts: 'docs.tgz', fingerprint: true) + sh( + script: "aws s3 cp --no-progress _docs s3://${s3_prefix}/docs --recursive", + label: 'Upload docs to S3', + ) } } } diff --git a/jenkins/Test.groovy.j2 b/jenkins/Test.groovy.j2 index d86575c247c75..d219b47bc7929 100644 --- a/jenkins/Test.groovy.j2 +++ b/jenkins/Test.groovy.j2 @@ -266,7 +266,10 @@ stage('Test') { ) } {{ m.upload_artifacts(tag='docs', filenames=["docs.tgz"]) }} - archiveArtifacts(artifacts: 'docs.tgz', fingerprint: true) + sh( + script: "aws s3 cp --no-progress _docs s3://${s3_prefix}/docs --recursive", + label: 'Upload docs to S3', + ) } } } From a2f89c53cc761a9ef8fa918105486b81a539a02b Mon Sep 17 00:00:00 2001 From: apeskov Date: Thu, 2 Jun 2022 22:24:24 +0300 Subject: [PATCH 024/181] Restore integration test on Mac and Windows (#11538) Signed-off-by: Alexander Peskov --- tests/python/contrib/test_dnnl.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index fecd776d7065e..76e3f1c3a4055 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -17,6 +17,8 @@ import pytest import itertools import numpy as np +import sys +import subprocess import tvm from tvm import relay @@ -37,7 +39,21 @@ ids=["compile", "run"], ) -bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() +_bf16_supported = None + + +def bf16_supported(): + global _bf16_supported + if _bf16_supported is None: + _bf16_supported = False + if sys.platform.startswith("darwin"): + cpu_info = subprocess.check_output("sysctl -a", shell=True).strip().decode() + for line in cpu_info.split("\n"): + if line.startswith("hw.optional.avx512f"): + _bf16_supported = bool(line.split(":", 1)[1]) + elif sys.platform.startswith("linux"): + _bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() + return _bf16_supported def partition_for_dnnl(mod, params=None, alter_layout=True): @@ -150,7 +166,7 @@ def check_dnnl_used(mod, subgraph_num=None): (True, False, False), (True, True, False), ] - if test_bf16 and bf16_supported: + if test_bf16 and bf16_supported(): configs += [(True, False, True), (True, True, True)] for use_dnnl, alter_layout, use_bf16 in configs: result_key = ( From 03eefe0b41587fecb910f3543b0ddc1adeb4fcff Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Thu, 2 Jun 2022 12:43:06 -0700 Subject: [PATCH 025/181] [ci] Add @tvm-bot rerun (#11480) This adds a command to restart CI runs that have stopped (either from a failure, success, or abort) via GitHub comments addressed to tvm-bot: ``` @tvm-bot rerun ``` tvm-bot will then comment on the thread and send a request to Jenkins to restart CI. This does not restart GitHub Actions jobs though we may be able to add that in the future. 
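Under the hood the new `Rerun` command simply pokes Jenkins: it POSTs to the PR
job's `buildWithParameters` endpoint as the `tvm-bot` user, authenticated with
the `TVM_BOT_JENKINS_TOKEN` secret (see `rerun_jenkins_ci` and the `post`
helper in the diff below). A minimal standalone sketch of that call follows;
the PR number is a placeholder and the token is read from the environment,
this is an illustration rather than the bot's exact code path:

```
# Sketch of the Jenkins trigger behind "@tvm-bot rerun".
# Endpoint, user name, and token variable come from this patch; the PR
# number below is only a placeholder for illustration.
import base64
import os
from urllib import request

JENKINS_URL = "https://ci.tlcpack.ai/"
pr_number = 11480  # placeholder

url = JENKINS_URL + f"job/tvm/job/PR-{pr_number}/buildWithParameters"
token = os.environ.get("TVM_BOT_JENKINS_TOKEN", "")
auth = base64.b64encode(f"tvm-bot:{token}".encode("utf-8")).decode("utf-8")

req = request.Request(url, data=b"", method="POST")
req.add_header("Authorization", f"Basic {auth}")
with request.urlopen(req) as response:
    print(response.status)
```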
Co-authored-by: driazati --- .github/workflows/{merge.yml => tvmbot.yml} | 11 +- tests/python/ci/sample_prs/pr10786-badci.json | 3 +- .../sample_prs/pr10786-changes-requested.json | 3 +- .../ci/sample_prs/pr10786-co-authors.json | 2 +- .../ci/sample_prs/pr10786-invalid-author.json | 3 +- .../python/ci/sample_prs/pr10786-merges.json | 2 +- .../ci/sample_prs/pr10786-missing-job.json | 2 +- .../ci/sample_prs/pr10786-nottriggered.json | 2 +- .../ci/sample_prs/pr10786-oldreview.json | 2 +- .../pr11244-unauthorized-comment.json | 2 +- .../ci/sample_prs/pr11267-no-review.json | 4 +- .../ci/sample_prs/pr11276-no-review.json | 157 ------------- ...o-recomment.json => pr11442-rerun-ci.json} | 12 +- tests/python/ci/test_mergebot.py | 66 ++++-- tests/scripts/git_utils.py | 22 ++ .../{github_mergebot.py => github_tvmbot.py} | 219 +++++++++++------- 16 files changed, 239 insertions(+), 273 deletions(-) rename .github/workflows/{merge.yml => tvmbot.yml} (62%) delete mode 100644 tests/python/ci/sample_prs/pr11276-no-review.json rename tests/python/ci/sample_prs/{pr11442-no-recomment.json => pr11442-rerun-ci.json} (95%) rename tests/scripts/{github_mergebot.py => github_tvmbot.py} (80%) diff --git a/.github/workflows/merge.yml b/.github/workflows/tvmbot.yml similarity index 62% rename from .github/workflows/merge.yml rename to .github/workflows/tvmbot.yml index efbada4b00a46..c9d2cf71e6a70 100644 --- a/.github/workflows/merge.yml +++ b/.github/workflows/tvmbot.yml @@ -1,5 +1,5 @@ -name: Merge +name: tvm-bot on: status: pull_request_review: @@ -12,16 +12,19 @@ concurrency: cancel-in-progress: true jobs: - maybe-merge: + run-tvm-bot: if: github.repository == 'apache/tvm' runs-on: ubuntu-20.04 + if: ${{ github.event.issue.pull_request }} steps: - uses: actions/checkout@v2 - - name: Merge if requested and possible + - name: Run tvm-bot env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TVM_BOT_JENKINS_TOKEN: ${{ secrets.TVM_BOT_JENKINS_TOKEN }} PR_NUMBER: ${{ github.event.issue.number }} + ISSUE_COMMENT: ${{ toJson(github.event.comment) }} RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} run: | set -eux - python tests/scripts/github_mergebot.py --pr "$PR_NUMBER" --run-url "$RUN_URL" + python tests/scripts/github_tvmbot.py --pr "$PR_NUMBER" --run-url "$RUN_URL" --trigger-comment-json "$ISSUE_COMMENT" diff --git a/tests/python/ci/sample_prs/pr10786-badci.json b/tests/python/ci/sample_prs/pr10786-badci.json index b49899b86bcae..7e9d10d0b6481 100644 --- a/tests/python/ci/sample_prs/pr10786-badci.json +++ b/tests/python/ci/sample_prs/pr10786-badci.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. 
Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { @@ -119,6 +119,7 @@ "commit": { "oid": "6f04bcf57d07f915a98fd91178f04d9e92a09fcd" }, + "id": 123, "author": { "login": "kparzysz-quic" }, diff --git a/tests/python/ci/sample_prs/pr10786-changes-requested.json b/tests/python/ci/sample_prs/pr10786-changes-requested.json index 46b13a7f6c6c0..24e261099a4ff 100644 --- a/tests/python/ci/sample_prs/pr10786-changes-requested.json +++ b/tests/python/ci/sample_prs/pr10786-changes-requested.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { @@ -120,6 +120,7 @@ "commit": { "oid": "6f04bcf57d07f915a98fd91178f04d9e92a09fcd" }, + "id": 123, "author": { "login": "kparzysz-quic" }, diff --git a/tests/python/ci/sample_prs/pr10786-co-authors.json b/tests/python/ci/sample_prs/pr10786-co-authors.json index a660c9d9b214a..75f2728250597 100644 --- a/tests/python/ci/sample_prs/pr10786-co-authors.json +++ b/tests/python/ci/sample_prs/pr10786-co-authors.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr10786-invalid-author.json b/tests/python/ci/sample_prs/pr10786-invalid-author.json index d19d6dad8a442..81b028e3196ae 100644 --- a/tests/python/ci/sample_prs/pr10786-invalid-author.json +++ b/tests/python/ci/sample_prs/pr10786-invalid-author.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. 
Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { @@ -114,6 +114,7 @@ "nodes": [ { "body": "@tvm-bot merge", + "id": 123, "updatedAt": "2022-03-25T22:13:50Z", "authorCanPushToRepository": false, "commit": { diff --git a/tests/python/ci/sample_prs/pr10786-merges.json b/tests/python/ci/sample_prs/pr10786-merges.json index c7b6940f0d5b3..0226c8ab52454 100644 --- a/tests/python/ci/sample_prs/pr10786-merges.json +++ b/tests/python/ci/sample_prs/pr10786-merges.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free.\n\n\nThanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @ them in the pull request thread.\n\n\nPreviously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\n\n\ncc @someone\n\r\n\r\nCo-authored-by: Adam Straw \n\n\nThanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @ them in the pull request thread.\n\n", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr10786-missing-job.json b/tests/python/ci/sample_prs/pr10786-missing-job.json index 81be0ebe47950..13739b793fb53 100644 --- a/tests/python/ci/sample_prs/pr10786-missing-job.json +++ b/tests/python/ci/sample_prs/pr10786-missing-job.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. 
Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr10786-nottriggered.json b/tests/python/ci/sample_prs/pr10786-nottriggered.json index 11c5976bd6e40..0da541c4342df 100644 --- a/tests/python/ci/sample_prs/pr10786-nottriggered.json +++ b/tests/python/ci/sample_prs/pr10786-nottriggered.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr10786-oldreview.json b/tests/python/ci/sample_prs/pr10786-oldreview.json index 27ba0e8729181..1a2556cb6f5f1 100644 --- a/tests/python/ci/sample_prs/pr10786-oldreview.json +++ b/tests/python/ci/sample_prs/pr10786-oldreview.json @@ -3,7 +3,7 @@ "body": "- Added device validity check in allocation. HexagonDeviceAPI should only be called for CPU/Hexagon types.\r\n\r\n- Check for \"global.vtcm\" scope instead of \"vtcm\". The ccope of N-d allocations produced by `LowerVtcmAlloc` should be `\"global.vtcm\"`. The previous check allowed unsupported scope such as `\"local.vtcm\"`.\r\n\r\n- Remove `vtcmallocs` entry after calling free. Previously, the vtcm allocation map kept dangling pointers to `HexagonBuffer` objects after they had been freed.\r\n\r\n- Rename N-d alloc and free packed functions. Since most of the similar device functions use snake case, renaming `*.AllocND` to `*.alloc_nd` and `*.FreeND` to `*.free_nd`.\r\n\r\nCo-authored-by: Adam Straw ", "state": "OPEN", "author": { - "login": "Lunderberg" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr11244-unauthorized-comment.json b/tests/python/ci/sample_prs/pr11244-unauthorized-comment.json index 206adc9a9eacf..beafc05958b64 100644 --- a/tests/python/ci/sample_prs/pr11244-unauthorized-comment.json +++ b/tests/python/ci/sample_prs/pr11244-unauthorized-comment.json @@ -3,7 +3,7 @@ "body": "See [this thread ](https://discuss.tvm.apache.org/t/crt-add-platform-specific-pre-and-post-function-calls-in-crt-runtime/12723)for an explanation.", "state": "OPEN", "author": { - "login": "fPecc" + "login": "abc" }, "comments": { "pageInfo": { diff --git a/tests/python/ci/sample_prs/pr11267-no-review.json b/tests/python/ci/sample_prs/pr11267-no-review.json index 31577671f0b6b..d2ad164673e5a 100644 --- a/tests/python/ci/sample_prs/pr11267-no-review.json +++ b/tests/python/ci/sample_prs/pr11267-no-review.json @@ -3,7 +3,7 @@ "body": "This adds `/opt/sccache` to the PATH of each of the CI docker images so when cmake looks for a C compiler it will pick up the sccache wrapper by default. 
This fixes some issues where compiler invocations weren't being run though sccache. With this approach the invoker doesn't need to do anything specific to set up sccache.\n\nThis will require a follow up PR to update the Docker images and remove some of the sccache logic in `task_build.py`\n\n\n\ncc @Mousius @areusch", "state": "OPEN", "author": { - "login": "driazati" + "login": "abc" }, "comments": { "pageInfo": { @@ -15,6 +15,7 @@ "author": { "login": "areusch" }, + "id": 124, "updatedAt": "2022-05-11T16:54:32Z", "body": "just confirming--we can disable this when doing a local build, correct? what's the mechanism by which we do that?" }, @@ -23,6 +24,7 @@ "author": { "login": "driazati" }, + "id": 123, "updatedAt": "2022-05-11T18:46:54Z", "body": "@tvm-bot merge" } diff --git a/tests/python/ci/sample_prs/pr11276-no-review.json b/tests/python/ci/sample_prs/pr11276-no-review.json deleted file mode 100644 index 3f8459eb00f7b..0000000000000 --- a/tests/python/ci/sample_prs/pr11276-no-review.json +++ /dev/null @@ -1,157 +0,0 @@ -{ - "title": "[COMMUNITY] mikepapadim -> Reviewer", - "body": "Please join us to welcome Michalis Papadimitriou (@mikepapadim) as a new reviewer to TVM. Michalis has contributed a lot to BYOC and TensorRT backend.\r\n\r\n- [Commits History](https://github.com/apache/tvm/commits?author=mikepapadim)\r\n- [Code Review](https://github.com/apache/tvm/pulls?utf8=%E2%9C%93&q=reviewed-by:mikepapadim)\r\n- [Community Forum Summary](https://github.com/apache/tvm/commits?author=mikepapadim)", - "state": "OPEN", - "author": { - "login": "ZihengJiang" - }, - "comments": { - "pageInfo": { - "hasPreviousPage": false - }, - "nodes": [] - }, - "authorCommits": { - "nodes": [ - { - "commit": { - "authors": { - "nodes": [ - { - "name": "ZihengJiang", - "email": "ziheng@apache.org" - } - ] - } - } - } - ] - }, - "commits": { - "nodes": [ - { - "commit": { - "oid": "96075744cc687caafc131361d006c5967edddbc6", - "statusCheckRollup": { - "contexts": { - "pageInfo": { - "hasNextPage": false - }, - "nodes": [ - { - "name": "MacOS", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "CI" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391733373" - }, - { - "name": "cc-reviewers", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "PR" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391732791" - }, - { - "name": "cc-reviewers", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "PR" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391754960" - }, - { - "name": "tag-teams", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "Teams" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391732788" - }, - { - "name": "tag-teams", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "Teams" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391754947" - }, - { - "name": "Windows", - "checkSuite": { - "workflowRun": { - "workflow": { - "name": "CI" - } - } - }, - "status": "COMPLETED", - "conclusion": "SUCCESS", - "url": "https://github.com/apache/tvm/runs/6391733127" - }, - { - "state": "SUCCESS", - "context": "tvm-ci/branch", - "targetUrl": "https://ci.tlcpack.ai/job/tvm/job/ziheng%252Fcommunity/1/display/redirect" - }, - { - "state": 
"SUCCESS", - "context": "tvm-ci/pr-head", - "targetUrl": "https://ci.tlcpack.ai/job/tvm/job/PR-11276/1/display/redirect" - } - ] - } - } - } - } - ] - }, - "reviewDecision": "APPROVED", - "reviews": { - "pageInfo": { - "hasPreviousPage": false - }, - "nodes": [ - { - "body": "", - "updatedAt": "2022-05-11T16:50:16Z", - "url": "https://github.com/apache/tvm/pull/11276#pullrequestreview-969701502", - "authorCanPushToRepository": true, - "commit": { - "oid": "96075744cc687caafc131361d006c5967edddbc6" - }, - "author": { - "login": "tqchen" - }, - "state": "APPROVED" - } - ] - } -} \ No newline at end of file diff --git a/tests/python/ci/sample_prs/pr11442-no-recomment.json b/tests/python/ci/sample_prs/pr11442-rerun-ci.json similarity index 95% rename from tests/python/ci/sample_prs/pr11442-no-recomment.json rename to tests/python/ci/sample_prs/pr11442-rerun-ci.json index 77af805f2180e..0199b2921f648 100644 --- a/tests/python/ci/sample_prs/pr11442-no-recomment.json +++ b/tests/python/ci/sample_prs/pr11442-rerun-ci.json @@ -3,7 +3,7 @@ "body": "(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for\r\ncontext, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).\r\n\r\nThis adds a new 'DSO exportable' runtime module representing the contents of a .o file. It\r\nallows external codegen toolchains to yield a result which:\r\n - Like CSource modules, can be conveyed directly to the final export_library compilation\r\n step for linking into the final .so and saved to a know location without risk the\r\n underlying code artifact will be lost.\r\n - Like DSOLibrary modules, are self contained so that no additional compile-time arguments\r\n need be conveyed from the CSource module to the final export_library command line\r\n\r\nSince this is the third flavor of 'DSO exportable' module, add a Module::IsDSOExportable.\r\n\r\nSince adding the above, can't resist also adding a Module::ImplementsFunction virtual and\r\ncalling it from TEComplier to check if an external codegen function actually provided the\r\nimplementation it promised.\r\n\r\nNote:\r\n - I've left the existing implementation of runtime.load_module alone which\r\n relinks .o files to .so files.\r\n - Though also contained in the .o metadata, I require static libraries to always\r\n carry their list of exported function names.\r\n\r\nThis is all pretty stop gap pending a good rework of TVM to supoprt the notion of artifacts\r\nand, perhaps, build rules.\r\n", "state": "OPEN", "author": { - "login": "mbs-octoml" + "login": "abc" }, "comments": { "pageInfo": { @@ -64,15 +64,7 @@ "login": "mbs-octoml" }, "updatedAt": "2022-05-25T22:12:37Z", - "body": "Hmff." 
- }, - { - "authorAssociation": "NONE", - "author": { - "login": "github-actions" - }, - "updatedAt": "2022-05-25T22:12:55Z", - "body": "Cannot merge, did not find any approving reviews from users with write access on 96d4e62da5a7b78da18d0ee28cc6261d8fbf31c4" + "body": "@tvm-bot rerun" } ] }, diff --git a/tests/python/ci/test_mergebot.py b/tests/python/ci/test_mergebot.py index b9f944e897d3f..a565cc76a5c14 100644 --- a/tests/python/ci/test_mergebot.py +++ b/tests/python/ci/test_mergebot.py @@ -29,8 +29,8 @@ class TempGit: def __init__(self, cwd): self.cwd = cwd - def run(self, *args): - proc = subprocess.run(["git"] + list(args), cwd=self.cwd) + def run(self, *args, **kwargs): + proc = subprocess.run(["git"] + list(args), cwd=self.cwd, **kwargs) if proc.returncode != 0: raise RuntimeError(f"git command failed: '{args}'") @@ -50,87 +50,118 @@ def run(self, *args): "number": 10786, "filename": "pr10786-merges.json", "expected": SUCCESS_EXPECTED_OUTPUT, + "comment": "@tvm-bot merge", + "user": "abc", "detail": "Everything is fine so this PR will merge", }, "no-request": { "number": 10786, "filename": "pr10786-nottriggered.json", - "expected": "No merge requested, exiting", + "expected": "Command 'do something else' did not match anything", + "comment": "@tvm-bot do something else", + "user": "abc", "detail": "A PR for which the mergebot runs but no merge is requested", }, "bad-ci": { "number": 10786, "filename": "pr10786-badci.json", "expected": "Cannot merge, these CI jobs are not successful on", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "A PR which failed CI and cannot merge", }, "old-review": { "number": 10786, "filename": "pr10786-oldreview.json", "expected": "Cannot merge, did not find any approving reviews", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "A PR with passing CI and approving reviews on an old commit so it cannot merge", }, "missing-job": { "number": 10786, "filename": "pr10786-missing-job.json", "expected": "Cannot merge, missing expected jobs", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "PR missing an expected CI job and cannot merge", }, "invalid-author": { "number": 10786, "filename": "pr10786-invalid-author.json", - "expected": "No merge requested, exiting", + "expected": "Comment is not from from PR author or collaborator, quitting", + "comment": "@tvm-bot merge", + "user": "not-abc", "detail": "Merge requester is not a committer and cannot merge", }, "unauthorized-comment": { "number": 11244, "filename": "pr11244-unauthorized-comment.json", - "expected": "No merge requested, exiting", + "expected": "Comment is not from from PR author or collaborator, quitting", + "comment": "@tvm-bot merge", + "user": "not-abc2", "detail": "Check that a merge comment not from a CONTRIBUTOR is rejected", }, "no-review": { "number": 11267, "filename": "pr11267-no-review.json", "expected": "Cannot merge, did not find any approving reviews from users with write access", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "Check that a merge request without any reviews is rejected", }, "changes-requested": { "number": 10786, "filename": "pr10786-changes-requested.json", "expected": "Cannot merge, found [this review]", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "Check that a merge request with a 'Changes Requested' review on HEAD is rejected", }, "co-authors": { "number": 10786, "filename": "pr10786-co-authors.json", "expected": "Co-authored-by: Some One ", + "comment": "@tvm-bot merge", + "user": "abc", "detail": "Check that 
a merge request with co-authors generates the correct commit message", }, - "no-recomment": { + "rerun-ci": { "number": 11442, - "filename": "pr11442-no-recomment.json", - "expected": "No merge requested, exiting", - "detail": "Check that comments after a failed merge don't trigger another merge", + "filename": "pr11442-rerun-ci.json", + "expected": "Rerunning ci with", + "comment": "@tvm-bot rerun", + "user": "abc", + "detail": "Start a new CI job", }, } @pytest.mark.parametrize( - ["number", "filename", "expected", "detail"], + ["number", "filename", "expected", "comment", "user", "detail"], [tuple(d.values()) for d in test_data.values()], ids=test_data.keys(), ) -def test_mergebot(tmpdir_factory, number, filename, expected, detail): - mergebot_script = REPO_ROOT / "tests" / "scripts" / "github_mergebot.py" +def test_mergebot(tmpdir_factory, number, filename, expected, comment, user, detail): + mergebot_script = REPO_ROOT / "tests" / "scripts" / "github_tvmbot.py" test_json_dir = Path(__file__).resolve().parent / "sample_prs" git = TempGit(tmpdir_factory.mktemp("tmp_git_dir")) - git.run("init") - git.run("checkout", "-b", "main") + git.run("init", stderr=subprocess.PIPE, stdout=subprocess.PIPE) + git.run("checkout", "-b", "main", stderr=subprocess.PIPE, stdout=subprocess.PIPE) git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") with open(test_json_dir / filename) as f: test_data = json.load(f) + comment = { + "body": comment, + "id": 123, + "user": { + "login": user, + }, + } + collaborators = [] + proc = subprocess.run( [ str(mergebot_script), @@ -141,10 +172,17 @@ def test_mergebot(tmpdir_factory, number, filename, expected, detail): "https://example.com", "--testing-pr-json", json.dumps(test_data), + "--testing-collaborators-json", + json.dumps(collaborators), + "--trigger-comment-json", + json.dumps(comment), ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8", + env={ + "TVM_BOT_JENKINS_TOKEN": "123", + }, cwd=git.cwd, ) if proc.returncode != 0: diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py index 9f2468638cade..7cd1b6b2fe596 100644 --- a/tests/scripts/git_utils.py +++ b/tests/scripts/git_utils.py @@ -19,6 +19,7 @@ import json import subprocess import re +import base64 from urllib import request from typing import Dict, Tuple, Any, Optional, List @@ -29,6 +30,27 @@ def compress_query(query: str) -> str: return query +def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = None): + print(f"Requesting POST to", url, "with", body) + headers = {} + if auth is not None: + auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}") + request.add_header("Authorization", f"Basic {auth_str}") + + if body is None: + body = "" + + req.add_header("Content-Type", "application/json; charset=utf-8") + req = request.Request(url, headers=headers, method="POST") + data = json.dumps(body) + data = data.encode("utf-8") + req.add_header("Content-Length", len(data)) + + with request.urlopen(req, data) as response: + response = json.loads(response.read()) + return response + + class GitHubRepo: def __init__(self, user, repo, token): self.token = token diff --git a/tests/scripts/github_mergebot.py b/tests/scripts/github_tvmbot.py similarity index 80% rename from tests/scripts/github_mergebot.py rename to tests/scripts/github_tvmbot.py index 76e0803efc23a..bfdbeb4039e52 100755 --- a/tests/scripts/github_mergebot.py +++ b/tests/scripts/github_tvmbot.py @@ -23,17 +23,21 @@ import logging import traceback import re -from typing import 
Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Callable from pathlib import Path -from git_utils import git, GitHubRepo, parse_remote +from git_utils import git, GitHubRepo, parse_remote, post from cmd_utils import init_log Review = Dict[str, Any] CIJob = Dict[str, Any] +Comment = Dict[str, Any] +CommentChecker = Callable[[Comment], bool] EXPECTED_JOBS = ["tvm-ci/pr-head"] +TVM_BOT_JENKINS_TOKEN = os.environ["TVM_BOT_JENKINS_TOKEN"] +JENKINS_URL = "https://ci.tlcpack.ai/" THANKS_MESSAGE = r"(\s*)Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from \[Reviewers\]\(https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers\) by them in the pull request thread.(\s*)" @@ -41,6 +45,19 @@ def to_json_str(obj: Any) -> str: return json.dumps(obj, indent=2) +COLLABORATORS_QUERY = """ +query ($owner: String!, $name: String!, $user: String!) { + repository(owner: $owner, name: $name) { + collaborators(query: $user, first: 1) { + nodes { + login + } + } + } +} +""" + + PR_QUERY = """ query ($owner: String!, $name: String!, $number: Int!) { repository(owner: $owner, name: $name) { @@ -60,6 +77,7 @@ def to_json_str(obj: Any) -> str: author { login } + id updatedAt body } @@ -119,6 +137,7 @@ def to_json_str(obj: Any) -> str: body updatedAt url + id authorCanPushToRepository commit { oid @@ -202,6 +221,17 @@ def checker(obj, parent_key): def __repr__(self): return json.dumps(self.raw, indent=2) + def plus_one(self, comment: Dict[str, Any]): + """ + React with a thumbs up to a comment + """ + url = f"issues/comments/{comment['id']}/reactions" + data = {"content": "+1"} + if self.dry_run: + logging.info(f"Dry run, would have +1'ed to {url} with {data}") + else: + self.github.post(url, data=data) + def head_commit(self): return self.raw["commits"]["nodes"][0]["commit"] @@ -292,6 +322,19 @@ def fetch_data(self): }, )["data"]["repository"]["pullRequest"] + def search_collaborator(self, user: str) -> List[Dict[str, Any]]: + """ + Query GitHub for collaborators matching 'user' + """ + return self.github.graphql( + query=COLLABORATORS_QUERY, + variables={ + "owner": self.owner, + "name": self.repo_name, + "user": user, + }, + )["data"]["repository"]["collaborators"]["nodes"] + def comment(self, text: str) -> None: """ Leave the comment 'text' on this PR @@ -370,70 +413,8 @@ def merge(self) -> None: self.github.put(url, data=data) - def comment_can_merge(self, comment: Dict[str, Any]) -> bool: - """ - Check if a comment was left by the PR author or by a committer - """ - if comment["author"]["login"] == self.raw["author"]["login"]: - logging.info(f"Comment {comment} was from author and is mergeable") - return True - - if comment.get("authorAssociation", "") == "CONTRIBUTOR": - logging.info(f"Comment {comment} was from committer comment and is mergeable") - return True - - if comment.get("authorCanPushToRepository", False): - logging.info(f"Comment {comment} was from a committer review comment and is mergeable") - return True - - logging.info(f"Comment {comment} was not from author or committers and is not mergeable") - return False - - def merge_requested(self) -> bool: - """ - Check if this PR has had a merge requested - """ - merge_commands = [ - "merge", - "merge this", - "merge this pr", - ] - cancel_commands = [ - "cancel", - "cancel merge", - "cancel the merge", - "stop", - "stop merge", - "stop the merge", - ] - - def 
parse_action(comment: Dict[str, Any]) -> Optional[str]: - if comment["author"]["login"] == "github-actions": - return "commented" - - if not self.comment_can_merge(comment): - return None - - body = comment["body"] - if any(f"@tvm-bot {c}" in body for c in merge_commands): - return "merge" - - if any(f"@tvm-bot {c}" in body for c in cancel_commands): - return "cancel" - - return None - - # Check regular comments and top-level review comments - all_comments = self.raw["comments"]["nodes"] + self.reviews() - all_comments = sorted(all_comments, key=lambda comment: comment["updatedAt"]) - actions = [parse_action(comment) for comment in all_comments] - logging.info(f"Found these tvm-bot actions: {actions}") - actions = [a for a in actions if a is not None] - - if len(actions) == 0: - return False - - return actions[-1] == "merge" + def author(self) -> str: + return self.raw["author"]["login"] def find_failed_ci_jobs(self) -> List[CIJob]: # NEUTRAL is GitHub Action's way of saying cancelled @@ -502,6 +483,49 @@ def merge_if_passed_checks(self) -> None: self.comment(f"Cannot merge, CI did not pass on on {self.head_oid()}") return + def rerun_jenkins_ci(self) -> None: + url = JENKINS_URL + f"job/tvm/job/PR-{self.number}/buildWithParameters" + logging.info(f"Rerunning ci with URL={url}") + if self.dry_run: + logging.info("Dry run, not sending POST") + else: + post(url, auth=("tvm-bot", TVM_BOT_JENKINS_TOKEN)) + + +class Merge: + triggers = [ + "merge", + "merge this", + "merge this pr", + ] + + @staticmethod + def run(pr: PR): + try: + pr.merge_if_passed_checks() + except Exception as e: + if not args.dry_run: + msg = traceback.format_exc() + pr.comment( + f"Failed to process merge request in {args.run_url}\n\n
<details>\n\n```\n{msg}\n```\n\n</details>
" + ) + raise e + + +class Rerun: + triggers = [ + "rerun", + "rerun ci", + "re-run", + "re-run ci", + "run", + "run ci", + ] + + @staticmethod + def run(pr: PR): + pr.rerun_jenkins_ci() + if __name__ == "__main__": help = "Check if a PR has comments trying to merge it, and do so based on reviews/CI status" @@ -509,7 +533,13 @@ def merge_if_passed_checks(self) -> None: parser.add_argument("--remote", default="origin", help="ssh remote to parse") parser.add_argument("--pr", required=True, help="pr number to check") parser.add_argument("--run-url", required=True, help="workflow run URL") + parser.add_argument( + "--trigger-comment-json", required=True, help="json of the comment that triggered this run" + ) parser.add_argument("--testing-pr-json", help="(testing only) manual data for testing") + parser.add_argument( + "--testing-collaborators-json", help="(testing only) manual data for testing" + ) parser.add_argument( "--dry-run", action="store_true", @@ -518,7 +548,27 @@ def merge_if_passed_checks(self) -> None: ) args = parser.parse_args() init_log() + comment = json.loads(args.trigger_comment_json) + body = comment["body"].strip() + + # Check that the comment was addressed to tvm-bot + if not body.startswith("@tvm-bot "): + logging.info(f"Not a bot comment, '{body}' does not start with '@tvm-bot'") + exit(0) + # Find the code to run for the command from the user + user_command = body.lstrip("@tvm-bot").strip() + command_to_run = None + for command in [Merge, Rerun]: + if user_command in command.triggers: + command_to_run = command + break + + if command_to_run is None: + logging.info(f"Command '{user_command}' did not match anything") + exit(0) + + # Find the remote for querying more data about the PR remote = git(["config", "--get", f"remote.{args.remote}.url"]) logging.info(f"Using remote remote={remote}") owner, repo = parse_remote(remote) @@ -539,21 +589,34 @@ def merge_if_passed_checks(self) -> None: else: pr = PR(number=int(args.pr), owner=owner, repo=repo, dry_run=args.dry_run) + # Acknowledge the comment with a react + pr.plus_one(comment) + + # Check the comment author + comment_author = comment["user"]["login"] + if pr.author() == comment_author: + logging.info("Comment user is PR author, continuing") + else: + logging.info("Comment is not from PR author, checking collaborators") + # Get the list of collaborators for the repo filtered by the comment + # author + if args.testing_collaborators_json: + collaborators = json.loads(args.testing_collaborators_json) + else: + collaborators = pr.search_collaborator(comment_author) + logging.info(f"Found collaborators: {collaborators}") + + if len(collaborators) > 0: + logging.info("Comment is from collaborator") + else: + logging.info("Comment is not from from PR author or collaborator, quitting") + exit(0) + state = pr.state() if state != "OPEN": logging.info(f"Ignoring event on PR, state was not OPEN, instead was state={state}") exit(0) - if pr.merge_requested(): - try: - pr.merge_if_passed_checks() - except Exception as e: - if not args.dry_run: - msg = traceback.format_exc() - pr.comment( - f"Failed to process merge request in {args.run_url}\n\n
<details>\n\n```\n{msg}\n```\n\n</details>
" - ) - raise e - else: - logging.info("No merge requested, exiting") + # Run the command + command_to_run.run(pr) From c78539cc59b60b77794276699f9430cd5e838106 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 2 Jun 2022 15:08:13 -0500 Subject: [PATCH 026/181] [TIR][Arith] Additional Simplifications Inside Conditionals (#11524) * [TIR][Arith] Use equality constraints in analyzer Previously, constraints with inequalities were recognized and used for simplifications by `ConstIntBoundAnalyzer` and `ModularSetAnalyzer`, but constraints with equalities were not. This adds equality-based constraints. (e.g. Inside the then-case of `if i==5`, the value of `i` is known to be 5.) * [TIR][Arith] RewriteSimplifier, apply literal constraints Previously, constraints were only checked within a `tir.likely` annotation. After this change, constraints are used for simplification of all boolean expressions. (e.g. Within a conditional `if i==n`, the expression `(i==n) and (j==m)` can be simplified to `j==m`.) * [TIR][Arith] Do not apply literal constraints to BufferLoad If a literal constraint relies on the contents of a buffer, the constraint may not be assumed to hold. This prevents the incorrect rewriting of `A[i]==n` to true within a `if A[i]==n` conditional, as the value of `A[i]` may have changed. * [TIR][Arith] Use each independent constraints in RewriteSimplifier Inside a constraint `if i==n and j==m`, both `i==n` and `j==m` may be replaced with true, even in separate expressions. This commit uses a new internal utility function `tvm::arith::ExtractConstraints`, which breaks up a boolean expression into a list of true statements. This may be used to reduce duplication elsewhere, such as `const_int_bound.cc` and `iter_affine_map.cc`. * [TIR][Arith] Check for negation of literal constraints When inside a conditional of `i!=n`, in addition to the previous replacement of `i!=n` with true, we can also replace `i==n` with false. * [TIR][Arith] Added unittests for new simplifications * Fix lint error * Fixed handling of negation of non-boolean types * Removed extra asterisk --- src/arith/const_int_bound.cc | 3 + src/arith/constraint_extract.cc | 55 +++++ src/arith/constraint_extract.h | 58 +++++ src/arith/modular_set.cc | 4 + src/arith/rewrite_simplify.cc | 50 +++- src/arith/rewrite_simplify.h | 9 + .../unittest/test_tir_transform_simplify.py | 233 +++++++++++++++++- 7 files changed, 398 insertions(+), 14 deletions(-) create mode 100644 src/arith/constraint_extract.cc create mode 100644 src/arith/constraint_extract.h diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index cb125551c4683..4fd27a0fde10d 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -598,6 +598,9 @@ class ConstIntBoundAnalyzer::Impl if ((x < c).Match(cond)) { return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))}; } + if ((x == c).Match(cond) || (c == x).Match(cond)) { + return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))}; + } if ((x && y).Match(cond)) { auto ret1 = DetectBoundInfo(x.Eval()); auto ret2 = DetectBoundInfo(y.Eval()); diff --git a/src/arith/constraint_extract.cc b/src/arith/constraint_extract.cc new file mode 100644 index 0000000000000..d0bf57497e63e --- /dev/null +++ b/src/arith/constraint_extract.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/constraint_extract.cc + */ + +#include "constraint_extract.h" + +#include +#include + +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector* collect) { + collect->push_back(expr); + + PVar x, y; + if ((x && y).Match(expr)) { + CollectConstraints(x.Eval(), analyzer, collect); + CollectConstraints(y.Eval(), analyzer, collect); + } else if ((!(x || y)).Match(expr)) { + CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect); + CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect); + } +} + +std::vector ExtractConstraints(const PrimExpr& expr) { + std::vector out; + Analyzer analyzer; + CollectConstraints(expr, &analyzer, &out); + return out; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/constraint_extract.h b/src/arith/constraint_extract.h new file mode 100644 index 0000000000000..ea6e0a74419ce --- /dev/null +++ b/src/arith/constraint_extract.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file contraint_extract.h + * + * \brief Centralized location for extraction of constraints from a boolean expression. + */ + +#ifndef TVM_ARITH_CONSTRAINT_EXTRACT_H_ +#define TVM_ARITH_CONSTRAINT_EXTRACT_H_ + +#include + +#include + +namespace tvm { +namespace arith { + +/* \brief Returns constraints that are true if the expression is true. + * + * Utility to break up a boolean expression into independent + * constraints. + * + * Example: `i==5 && j==3` => `[i==5 && j==3, i==5, j==3]` + * Example: `i==5 || j==3` => `[i==5 || j==3]` + * Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]` + * + * Intended for use in bounds analysis or simplification within a + * conditional, or identifying independent conditionals that may be + * hoisted. 
+ * + * \param expr The expression to be analyzers + * + * \returns A vector of independent constraints + */ +std::vector ExtractConstraints(const PrimExpr& expr); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_CONSTRAINT_EXTRACT_H_ diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index afc28a5ed2859..4cad570ab3359 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -112,6 +112,10 @@ class ModularSetAnalyzer::Impl : public ExprFunctorvalue, base.Eval()->value); return UpdateByIntersect(var.Eval(), entry); } + if ((var == base).Match(constraint) || (base == var).Match(constraint)) { + Entry entry(1, base.Eval()->value); + return UpdateByIntersect(var.Eval(), entry); + } return nullptr; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f9e38dee48e50..a168e1f0836ca 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -32,6 +32,7 @@ #include "../target/datatype/registry.h" #include "const_fold.h" +#include "constraint_extract.h" #include "pattern_match.h" namespace tvm { @@ -228,7 +229,24 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constarint as well - literal_constraints_.push_back(operator()(constraint)); + PrimExpr new_constraint = operator()(constraint); + for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) { + if (SideEffect(subconstraint) <= CallEffectKind::kPure) { + literal_constraints_.push_back(subconstraint); + // We could apply this during TryMatchLiteralConstraint, but + // that would require performing a rewrite of each expression + // being checked. This way, we only apply a rewrite for each + // constraint being applied. 
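+      // For a boolean subconstraint such as `i == n`, the simplified negation below is
+      // `i != n`, and storing `Not(i != n)` lets TryMatchLiteralConstraint later rewrite
+      // occurrences of `i != n` to false. For a non-boolean subconstraint `v`, the
+      // negation is formed as `v == 0` instead.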
+ PrimExpr negation; + if (subconstraint.dtype().is_bool()) { + negation = Not(subconstraint); + } else { + negation = subconstraint == make_zero(subconstraint.dtype()); + } + negation = operator()(negation); + literal_constraints_.push_back(Not(negation)); + } + } size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { ICHECK_EQ(literal_constraints_.size(), new_literal_size); @@ -1291,11 +1309,27 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { return ret; } +Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const { + PrimExpr negation = Not(expr); + + ExprDeepEqual expr_equal; + for (const auto& constraint : literal_constraints_) { + if (expr_equal(constraint, expr)) { + return make_const(expr->dtype, true); + } + if (expr_equal(constraint, negation)) { + return make_const(expr->dtype, false); + } + } + return NullOpt; +} + PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression PVar x, y; @@ -1344,6 +1378,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression PVar x, y, z, s1, s2; @@ -1475,6 +1510,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { op = ret.as(); PrimExpr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + // Pattern var to match any expression PVar x, y; PVar lanes; @@ -1499,6 +1536,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression PVar x, y; @@ -1538,6 +1576,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); // Pattern var to match any expression PVar x, y; @@ -1602,13 +1641,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { return op->args[0] << op->args[1]; } } - ExprDeepEqual expr_equal; if (op->op.same_as(tir::builtin::likely())) { - for (const auto& constraint : literal_constraints_) { - // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } - if (expr_equal(constraint, op->args[0])) { - return make_const(op->dtype, true); - } + // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } + if (auto match = TryMatchLiteralConstraint(op->args[0])) { + return match.value(); } } return ret; diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 202b9209da6df..6007b6416742c 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -105,6 +105,15 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { */ bool CanInlineLet(const LetNode* op); + /*! 
\brief Internal function to apply constraints + * + * Tests whether the expression is known to be true or false based + * on existing constraints. If the expression or its negation + * matches a constraint, return the boolean it should be replaced + * with. Otherwise, return false. + */ + Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; + private: // Whether x >= val bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 01cc41c7cec75..4f727cd89b123 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -136,7 +136,24 @@ def sls(n, d): assert "if" not in str(stmt) -def test_load_store_noop(): +class BaseBeforeAfter: + def test_simplify(self): + before = self.before + before_mod = tvm.IRModule.from_expr(before) + after_mod = tvm.tir.transform.Simplify()(before_mod) + after = after_mod["main"] + expected = self.expected + + try: + tvm.ir.assert_structural_equal(after, expected) + except ValueError as err: + script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script() + raise ValueError( + f"Function after simplification did not match expected:\n{script}" + ) from err + + +class TestLoadStoreNoop(BaseBeforeAfter): """Store of a value that was just read from the same location is a no-op.""" @T.prim_func @@ -147,11 +164,8 @@ def before(A: T.Buffer[(1,), "float32"]): def expected(A: T.Buffer[(1,), "float32"]): T.evaluate(0) - after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"] - tvm.ir.assert_structural_equal(after, expected) - -def test_load_store_noop_after_simplify(): +class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter): """As test_load_store_noop, but requiring simplification to identify. Previously, a bug caused the self-assignment of a buffer to @@ -168,8 +182,213 @@ def before(A: T.Buffer[(1,), "float32"]): def expected(A: T.Buffer[(1,), "float32"]): T.evaluate(0) - after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"] - tvm.ir.assert_structural_equal(after, expected) + +class TestNestedCondition(BaseBeforeAfter): + """Nested IfThenElse with the same condition can be simplified. + + Requires const_int_bound to narrow scope of i within the + conditional, or for rewrite_simplify to recognize the literal + constraint. + """ + + @T.prim_func + def before(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + A[i] = 0.0 + + +class TestNestedProvableCondition(BaseBeforeAfter): + """Simplify inner conditional using constraint from outer. + + Requires const_int_bound to narrow scope of i within the + conditional. + """ + + @T.prim_func + def before(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i < 7: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + A[i] = 0.0 + + +class TestNestedVarCondition(BaseBeforeAfter): + """Simplify inner conditional using constraint from outer. + + Requires for rewrite_simplify to recognize the repeated + constraint. 
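+
+    Because ``n`` is a free variable, ConstIntBoundAnalyzer cannot bound ``i``
+    from the ``i == n`` condition; the simplification instead relies on
+    RewriteSimplifier matching the repeated literal constraint.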
+ """ + + @T.prim_func + def before(A: T.Buffer[(16,), "float32"], n: T.int32): + for i in T.serial(16): + if i == n: + if i == n: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(16,), "float32"], n: T.int32): + for i in T.serial(16): + if i == n: + A[i] = 0.0 + + +class TestAlteredBufferContents(BaseBeforeAfter): + """No simplification of data-dependent conditionals. + + A literal constraint must not be propagated if the values + referenced may change. TIR requires single assignment of + variables, so Var objects may be assumed constant, but BufferLoad + may not. + """ + + @T.prim_func + def before(A: T.Buffer[(1,), "int32"], n: T.int32): + if A[0] == n: + A[0] = A[0] + 1 + if A[0] == n: + A[0] = 0 + + expected = before + + +class TestNegationOfCondition(BaseBeforeAfter): + """Use negation of outer condition to simplify innner. + + Within the body of an if statement, the negation of the + condition is known to be false. + """ + + @T.prim_func + def before(A: T.Buffer[(16,), "int32"]): + for i in T.serial(16): + if i == 5: + if i != 5: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer[(16,), "int32"]): + for i in T.serial(16): + if i == 5: + A[i] = 1 + + +class TestNegationOfNotEqual(BaseBeforeAfter): + """As TestNegationOfVarCondition, but with a != outer condition. + + Because ConstIntBoundAnalyzer only tracks the min and max allowed + values, the outer i!=5 condition does provide a constraint on the + bounds. This test relies on RewriteSimplifier to recognize + ``i==5`` as the negation of a literal constraint. + """ + + @T.prim_func + def before(A: T.Buffer[(16,), "int32"]): + for i in T.serial(16): + if i != 5: + if i == 5: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer[(16,), "int32"]): + for i in T.serial(16): + if i != 5: + A[i] = 1 + + +class TestNegationOfVarCondition(BaseBeforeAfter): + """As TestNegationOfVarCondition, but with a dynamic condition. + + This simplification cannot be done with ConstIntBoundAnalyzer, and + must rely on RewriteSimplifier recognizing the repeated literal. + """ + + @T.prim_func + def before(A: T.Buffer[(16,), "int32"], n: T.int32): + for i in T.serial(16): + if i == n: + if i != n: + A[i] = 0 + else: + A[i] = 1 + + @T.prim_func + def expected(A: T.Buffer[(16,), "int32"], n: T.int32): + for i in T.serial(16): + if i == n: + A[i] = 1 + + +class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter): + """Split a boolean AND into independent constraints + + A single if condition may impose multiple literal constraints. + Each constraint that is ANDed together to form the condition + should be treated as an independent constraint. The use of n in + the condition is to ensure we exercise RewriteSimplifier. + """ + + @T.prim_func + def before(A: T.Buffer[(16, 16), "int32"], n: T.int32): + for i, j in T.grid(16, 16): + if i == n and j == n: + if i == n: + A[i, j] = 0 + + @T.prim_func + def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): + for i, j in T.grid(16, 16): + if i == n and j == n: + A[i, j] = 0 + + +class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter): + """Split a boolean OR into independent constraints + + Similar to TestLiteralConstraintSplitBooleanAnd, but splitting a + boolean OR into independent conditions. This uses the + simplification that ``!(x || y) == !x && !y``. + + The use of ``n`` in the condition is to ensure we exercise + RewriteSimplifier. 
+ """ + + @T.prim_func + def before(A: T.Buffer[(16, 16), "int32"], n: T.int32): + for i, j in T.grid(16, 16): + if i == n or j == n: + A[i, j] = 0 + else: + if i == n: + A[i, j] = 1 + else: + A[i, j] = 2 + + @T.prim_func + def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): + for i, j in T.grid(16, 16): + if i == n or j == n: + A[i, j] = 0 + else: + A[i, j] = 2 if __name__ == "__main__": From 12a0f3edcf8295288f4aa9ec3dbb6771c3a1a301 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 2 Jun 2022 14:34:23 -0700 Subject: [PATCH 027/181] [TIR] Add schedule primitive ReIndex (#11515) --- include/tvm/tir/schedule/schedule.h | 13 + python/tvm/tir/schedule/schedule.py | 73 +++ src/tir/schedule/concrete_schedule.cc | 10 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 15 + .../schedule/primitive/cache_read_write.cc | 468 ++++++++++++++++++ src/tir/schedule/schedule.cc | 5 + src/tir/schedule/traced_schedule.cc | 12 + src/tir/schedule/traced_schedule.h | 2 + src/tir/schedule/transform.cc | 26 + src/tir/schedule/transform.h | 21 + .../unittest/test_tir_schedule_reindex.py | 203 ++++++++ 12 files changed, 850 insertions(+) create mode 100644 tests/python/unittest/test_tir_schedule_reindex.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 48014280a5589..68900e107d7c9 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /*! + * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. + * The layout of the cache will be the same as by the iterators of the block that reads/writes the + * buffer. It requires: + * 1) There is only one block who reads/writes the target buffer + * 2) There is only one buffer load/store of this buffer in the block + * \param block_rv The block operates on the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \return The reindex stage block. + */ + virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f86228848b9d2..4179088aa534d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + @type_checked + def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV: + """Create a block that read/write a buffer region into a read/write cache with reindexing. + The layout of the cache will be the same as by the iterators of the block that reads/writes + the buffer. 
It requires: + 1) There is only one block who reads/writes the target buffer + 2) There is only one buffer load/store of this buffer in the block + + Parameters + ---------- + block: BlockRV + The block that accesses the target buffer + buffer_index: int + The index of the buffer in block's read or write region + buffer_index_type : str + Type of the buffer index, "read" or "write" + + Returns + ------- + reindex_block : BlockRV + The block of the reindex stage + + Examples + -------- + + Before transform_layout, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_reindex( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] * 2.0 + + Create the schedule and do transform_layout: + + .. code-block:: python + + sch = tir.Schedule(before_reindex) + block = sch.get_block("B") + sch.reindex(block, 0, "read) + + After applying reindex, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_reindex( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + A_reindex = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("A_reindex"): + vi, vj = T.axis.remap("SS", [i, j]) + A_reindex[vi, vj] = A[vj, vi] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_reindex[vi, vj] * 2.0 + + """ + assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" + buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 + return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member + self, block, buffer_index, buffer_index_type_enum + ) + ########## Schedule: Compute location ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 2289899c329bb..590a0f0025954 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); + TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8e83aac2ce823..70c0265611c31 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/primitive.h 
b/src/tir/schedule/primitive.h index 50dedf71ff528..f4dba69c6b156 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); +/*! + *! + * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. + * The layout of the cache will be the same as by the iterators of the block that reads/writes the + * buffer. It requires: + * 1) There is only one block who reads/writes the target buffer + * 2) There is only one buffer load/store of this buffer in the block + * \param self The state of the schedule + * \param block_rv The block operates on the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \return The reindex stage block. + */ +TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type); /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 1bba2ae4fc611..c96f88e1f6333 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -160,6 +160,121 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, return block; } +/*! + * \brief Create the reindex block and generate the corresponding outer loops. + * \details The reindex block is a data copy block between the reindex buffer (the intermediate + * buffer), and the target buffer. + If buffer_index_type == kWrite, copy from the reindex buffer to the target buffer. + If buffer_index_type == kRead, copy from the target buffer to the reindex buffer. + The reindex block has the same block iters and the surrounding loops as the input block. + However, if a block iter is not used in the indices of the target buffer being reindexed, the + domain of the block iter, and the corresponding outer loop, will become constant value one, making + it a trivial iter. + * \param block The block to be reindexed + * \param info The cache info + * \param covered The set of block iter vars covered in the buffer access indices + * \param original_indices The original buffer access indices + * \param buffer_index The index of the target buffer + * \param buffer_index_type The type of buffer index + * \return The reindex block. + */ +Block MakeReIndexStage(const Block& block, CacheStageInfo* info, + const std::unordered_set& covered, + const Array& original_indices, int buffer_index, + BufferIndexType buffer_index_type) { + // iters of the reindex block + Array new_block_iters; + // the substition map from the original block iter to the iters of the reindex block + std::unordered_map block_var_replace_map; + // block access region of reindexed buffer and target buffer + Region reindex_region, target_region; + // indices to access the reindex buffer and the target buffer + Array reindex_indices, target_indices; + + // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the + // reindex buffer. 
+ for (const IterVar& iter : block->iter_vars) { + Var var("v" + std::to_string(new_block_iters.size())); + bool used = covered.count(iter->var); + new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1), + /*var=*/var, + /*IterVarType=*/kDataPar)); + if (used) { + reindex_indices.push_back(var); + reindex_region.push_back(Range::FromMinExtent(var, 1)); + } + block_var_replace_map[iter->var] = var; + } + + // Step 2: Replace the original block iters with the new block iters + BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite + ? block->writes[buffer_index] + : block->reads[buffer_index]; + target_region = Substitute(buffer_region->region, block_var_replace_map); + for (const PrimExpr& index : original_indices) { + target_indices.push_back(Substitute(index, block_var_replace_map)); + } + + // Step 3: Create the reindex block + + // The src and the dst region and indices of the data copy + Region src_region{nullptr}; + Region dst_region{nullptr}; + Array src_indices{nullptr}; + Array dst_indices{nullptr}; + + if (buffer_index_type == BufferIndexType::kWrite) { + src_region = reindex_region; + dst_region = target_region; + src_indices = reindex_indices; + dst_indices = target_indices; + } else { + src_region = target_region; + dst_region = reindex_region; + src_indices = target_indices; + dst_indices = reindex_indices; + } + + // Create the body block + Block new_block( + /*iter_vars=*/new_block_iters, + /*reads=*/ + {BufferRegion(info->read_buffer, src_region)}, + /*writes=*/ + {BufferRegion(info->write_buffer, dst_region)}, + /*name_hint=*/buffer_region->buffer->name + "_reindex", + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices)); + + // Step 4: Create surrounding loops + + // Create loop vars and bindings for block iters + std::vector loop_vars; // loop variables + std::vector iter_values; // bindings in block realize + for (int i = 0; i < static_cast(block->iter_vars.size()); ++i) { + Var loop_var("ax" + std::to_string(loop_vars.size())); + loop_vars.push_back(loop_var); + iter_values.push_back(loop_var); + } + + // Create the block realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/new_block); + + // Create the chain of loops + for (int i = static_cast(new_block_iters.size()) - 1; i >= 0; --i) { + body = For(/*loop_var=*/loop_vars[i], + /*min=*/new_block_iters[i]->dom->min, + /*extent=*/new_block_iters[i]->dom->extent, + /*kind=*/ForKind::kSerial, + /*body=*/std::move(body)); + } + // Update cache info, which will be used in the later rewriting. + info->cache_stage = std::move(body); + return new_block; +} + /*! * \brief Recalculate the `affine_binding` flag of a specifc block * \param block_sref The sref to the specific block @@ -599,6 +714,252 @@ class CacheWriteRewriter : public StmtExprMutator { bool under_writer_block_{false}; }; +/*! + * \brief Create a new buffer by change the shape with block iters to be used as the reindex buffer + * \param buffer The given buffer. + * \param block_iters The block iters. + * \param covered Set of block iter vars covered by the buffer access indices + * \return The new buffer with target shape. 
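+ *
+ * For example, with covered block iters i and j of extent 128 and an uncovered reduction
+ * iter k, the resulting buffer shape is [128, 128]; uncovered iters contribute no
+ * dimension. (The extent of 128 is illustrative only.)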
+ */ +Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_iters, + const std::unordered_set& covered) { + ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_var = make_object(*buffer->data.get()); + std::vector new_shape; + std::vector new_strides; + for (const auto& iter : block_iters) { + if (covered.count(iter->var)) { + new_shape.push_back(iter->dom->min + iter->dom->extent); + } + } + new_strides.clear(); + new_buffer->shape = new_shape; + new_buffer->strides = new_strides; + new_buffer->data = buffer->data.copy_with_suffix("_reindex"); + new_buffer->name = buffer->name + "_reindex"; + return Buffer(new_buffer); +} + +/*! \brief The schedule error that the target is not a leaf block. */ +class NotLeafBlockError : public ScheduleError { + public: + NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} + String FastErrorString() const final { + return "ScheduleError: The target block is not a leaf block."; + } + + String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; +}; + +/*! \brief The schedule error that the buffer access is invalid for reindex. */ +class InvalidBufferAccessError : public ScheduleError { + public: + enum class ErrorKind { + kNoAccess, // buffer access not found + kNonUniqueAccess, // multiple buffer accesses with different indices + kOpaqueAccess, // opaque access to the buffer + }; + + InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind) + : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {} + String FastErrorString() const final { + return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The " + "indices should be the same if there are multiple accesses to the target buffer."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_->name + << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices " + "should be the same if there are multiple accesses to the target buffer. "; + if (kind_ == ErrorKind::kNoAccess) { + os << "No buffer accesses found."; + } else if (kind_ == ErrorKind::kNonUniqueAccess) { + os << "Multiple buffer accesses have non-unique indices."; + } else if (kind_ == ErrorKind::kOpaqueAccess) { + os << "Opaque buffer accesses found."; + } + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; + ErrorKind kind_; +}; + +/*! 
\brief Collect the related Load/Store to reindex */ +class ReIndexCollector : public StmtExprVisitor { + public: + static Array Collect(const IRModule& mod, const Buffer& buffer, const Block& block) { + ReIndexCollector collector(mod, buffer, block); + collector(block->body); + if (!collector.buffer_access_indices_.defined()) { + throw InvalidBufferAccessError(mod, buffer, block, + InvalidBufferAccessError::ErrorKind::kNoAccess); + } + return collector.buffer_access_indices_.value(); + } + + private: + explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const Block& block) + : mod_(mod), buffer_(buffer), block_(block) {} + + void VisitExpr_(const BufferLoadNode* load) final { + StmtExprVisitor::VisitExpr_(load); + if (load->buffer.same_as(buffer_)) { + CheckAndUpdateBufferAccessIndices(load->indices); + } + } + + void VisitStmt_(const BlockNode* block) final { + // no sub-blocks under this block + throw NotLeafBlockError(mod_, block_); + } + + void VisitStmt_(const BufferStoreNode* store) final { + StmtExprVisitor::VisitStmt_(store); + if (store->buffer.same_as(buffer_)) { + CheckAndUpdateBufferAccessIndices(store->indices); + } + } + + void CheckAndUpdateBufferAccessIndices(const Array indices) { + if (!buffer_access_indices_.defined()) { + buffer_access_indices_ = indices; + return; + } else if (!std::equal(buffer_access_indices_.value().begin(), + buffer_access_indices_.value().end(), indices.begin(), indices.end(), + ExprDeepEqual())) { + throw InvalidBufferAccessError(mod_, buffer_, block_, + InvalidBufferAccessError::ErrorKind::kNonUniqueAccess); + } + } + + void VisitExpr_(const VarNode* var) final { + if (var == buffer_->data.get()) { + throw InvalidBufferAccessError(mod_, buffer_, block_, + InvalidBufferAccessError::ErrorKind::kOpaqueAccess); + } + } + /*! \brief The IR module */ + IRModule mod_; + /*! \brief The buffer to rewrite */ + Buffer buffer_; + /*! \brief The block to visit */ + Block block_; + /*! \brief The indices of buffer acess to rewrite */ + Optional> buffer_access_indices_; +}; + +/*! \brief Mutator of ReIndex */ +class ReIndexRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info, + const std::unordered_set& covered) { + ReIndexRewriter rewriter(block_sref, info, covered); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info, + const std::unordered_set& covered) + : block_sref_(block_sref), info_(info), covered_(covered) { + new_buffer_ = info->alloc; + old_buffer_ = info->read_buffer.same_as(new_buffer_) ? 
info->write_buffer : info->read_buffer; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + if (is_scope_) { + is_scope_ = false; + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + // Insert cache stage into the loop + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + info_->block_reuse.Set(old_stmt, stmt); + return stmt; + } + + // Visiting the blokc being reindexed + if (block == block_sref_->stmt) { + // Collect the updated indices and regions + for (const IterVar& iter : block->iter_vars) { + if (covered_.count(iter->var)) { + indices_.push_back(iter->var); + region_.push_back(Range::FromMinExtent(iter->var, 1)); + } + } + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + // Update block reads/writes to use the intermediate reindex buffer + auto writes = + ReplaceBufferRegion(block->writes, old_buffer_, BufferRegion{new_buffer_, region_}); + auto reads = + ReplaceBufferRegion(block->reads, old_buffer_, BufferRegion{new_buffer_, region_}); + auto match_buffers = ReplaceBufferRegion(block->match_buffers, old_buffer_, + BufferRegion{new_buffer_, region_}); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->writes = std::move(writes); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + info_->block_reuse.Set(old_stmt, stmt); + return stmt; + } + return old_stmt; + } + + template + Node VisitBufferAccess(Node node) { + if (node->buffer.same_as(old_buffer_)) { + auto* n = node.CopyOnWrite(); + n->buffer = new_buffer_; + n->indices = indices_; + } + return node; + } + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore buffer_store = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(buffer_store)); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad buffer_load = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(buffer_load)); + } + + private: + /*! \brief The parent scope of the insertion. */ + const StmtSRef& block_sref_; + /*! \brief The info for inserting reindex stage. */ + CacheStageInfo* info_; + /*! \brief Whether old block var is covered in the indices */ + const std::unordered_set& covered_; + /*! \brief Whether the current block is scope block */ + bool is_scope_{true}; + /*! \brief The buffer to be replaced */ + Buffer old_buffer_; + /*! \brief The reindex buffer */ + Buffer new_buffer_; + /*! \brief The new indices */ + Array indices_; + /*! \brief The new region */ + Region region_; +}; + /******** Implementation ********/ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, @@ -729,6 +1090,80 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Block block = GetRef(block_ptr); + Buffer buffer = + GetNthAccessBuffer(self, block, buffer_index, buffer_index_type == BufferIndexType::kWrite); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + arith::Analyzer analyzer; + + // Step 1. 
Collect the original indices and check there's only single pattern of related + // Load/Store and the buffer is not accessed opaquely + Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); + // Simplify the indices if possible + for (const IterVar& iter : block->iter_vars) { + analyzer.Bind(iter->var, iter->dom); + } + original_indices.MutateByApply( + [&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); }); + + // Collect block iters appearing in the original_indices + std::unordered_set covered; + for (const PrimExpr& index : original_indices) { + PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { + if (const VarNode* var = obj.as()) { + covered.insert(GetRef(var)); + } + return true; + }); + } + + // Step 2. Creating CacheStageInfo + CacheStageInfo info; + // Create the corresponding buffer to be read(write), i.e. the result of reindex read(write) + if (buffer_index_type == BufferIndexType::kWrite) { + info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); + info.write_buffer = buffer; + info.alloc = info.read_buffer; + } else { + info.read_buffer = buffer; + info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); + info.alloc = info.write_buffer; + } + + // Step 3. Check the block belongs to a chain loop nesting under the scope, + // and get the insert location + const StmtSRefNode* loop; + for (loop = block_sref->parent; loop->parent != scope_sref.get();) { + const ForNode* outer = loop->parent->StmtAs(); + const ForNode* inner = loop->StmtAs(); + ICHECK(outer != nullptr && inner != nullptr); + ICHECK(outer->body.get() == inner); + loop = loop->parent; + } + + info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index; + if (buffer_index_type == BufferIndexType::kWrite) { + info.loc_pos++; + } + + // Step 4. Making new reindex stage block and rewrite + Block reindex_stage = + MakeReIndexStage(block, &info, covered, original_indices, buffer_index, buffer_index_type); + Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, covered); + + // Step 5. 
Replacing and updating flags + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + /******** Instruction Registration ********/ struct CacheReadTraits : public UnpackedInstTraits { @@ -787,7 +1222,40 @@ struct CacheWriteTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReIndexTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReIndex"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, + Integer buffer_index_type) { + return sch->ReIndex(block, buffer_index, + static_cast(buffer_index_type->value)); + } + + static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, + Integer buffer_index_type) { + PythonAPICall py("reindex"); + py.Input("block", block); + py.Input("buffer_index", buffer_index); + py.Input("buffer_index_type", '"' + + std::string(BufferIndexType2Str( + static_cast(buffer_index_type->value))) + + '"'); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fb884ce77f7b7..3880d0b19eeb8 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") + .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, + int buffer_index_type) { + return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); + }); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 8156480a4516b..d2f627edfd11d 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -265,6 +265,18 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) { + BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); + + static const InstructionKind& kind = InstructionKind::Get("ReIndex"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, + /*outputs=*/{result})); + return result; +} + /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h index d1860be9512d7..ba4a4b99cbb2d 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -73,6 +73,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 79802ecd65dbb..67d0f55f20b9f 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -70,6 +70,32 @@ Array ReplaceBuffer(Array match_buffers, c return match_buffers; } +Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, + const BufferRegion& target) { + regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion { + if (region->buffer.same_as(source_buffer)) { + return target; + } + return region; + }); + return regions; +} + +Array ReplaceBufferRegion(Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target) { + match_buffers.MutateByApply([&source_buffer, &target]( + const MatchBufferRegion& match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source_buffer)) { + ObjectPtr n = make_object(*match_buffer.get()); + n->source = target; + return MatchBufferRegion(n); + } + return match_buffer; + }); + return match_buffers; +} + /******** ReplaceBufferMutator ********/ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, Map* block_sref_reuse) diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 192d44d9e9adc..908a823c2d860 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -73,6 +73,27 @@ Array ReplaceBuffer(Array regions, const Buffer& sou Array ReplaceBuffer(Array match_buffers, const Buffer& source, const Buffer& target); +/*! + * \brief Replaces the buffer region within the specific sequence of regions + * \param regions The regions to be replaced + * \param source_buffer The buffer to whose region is to be replaced + * \param target The buffer region to be replaced to + * \return The new sequence of regions after replacement + */ +Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, + const BufferRegion& target); + +/*! + * \brief Replaces the buffer region within the specific sequence of match_buffers + * \param regions The match_buffers to be replaced + * \param source_buffer The buffer to whose region is to be replaced + * \param target The buffer region to be replaced to + * \return The new sequence of match_buffers after replacement + */ +Array ReplaceBufferRegion(Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target); + /*! * \brief A helper mutator which recursively replaces the old buffer with the new buffer and * collects the block sref reuse information for the following replacement. 
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py new file mode 100644 index 0000000000000..9b2e37a19813a --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.schedule import ScheduleError +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def transpose_elementwise( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] * 2.0 + + +@T.prim_func +def transpose_elementwise_reindex_read( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + A_reindex = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("A_reindex"): + vi, vj = T.axis.remap("SS", [i, j]) + A_reindex[vi, vj] = A[vj, vi] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_reindex[vi, vj] * 2.0 + + +@T.prim_func +def conv2d_nhwc( + Input: T.Buffer[(1, 224, 224, 3), "float32"], + Weight: T.Buffer[(7, 7, 3, 64), "float32"], + Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), + Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 64) * 3) + rc)] + * Weight[rh, rw, rc, co] + ) + + +@T.prim_func +def conv2d_nhwc_reindex_weight( + var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle +) -> None: + inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32") + weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32") + conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64], dtype="float32") + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + weight_reindex = T.alloc_buffer([64, 7, 7, 3], 
dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227, + inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], + T.float32(0), + dtype="float32", + ) + for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3): + with T.block("weight_reindex"): + v0, v1, v2, v3, v4, v5, v6 = T.axis.remap( + "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6] + ) + T.reads(weight[v4, v5, v6, v3]) + T.writes(weight_reindex[v3, v4, v5, v6]) + weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads( + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], + weight_reindex[co, rh, rw, rc], + ) + T.writes(conv2d_nhwc[n, h, w, co]) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float32(0) + conv2d_nhwc[n, h, w, co] = ( + conv2d_nhwc[n, h, w, co] + + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] + * weight_reindex[co, rh, rw, rc] + ) + + +@T.prim_func +def matmul( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], +) -> None: + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("matmul"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + +@T.prim_func +def matmul_reindex_write( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], +) -> None: + C_reindex = T.alloc_buffer([512, 512], dtype="float32") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("matmul"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C_reindex[i, j], A[i, k], B[k, j]) + T.writes(C_reindex[i, j]) + with T.init(): + C_reindex[i, j] = T.float32(0) + C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j] + for i0, i1, i2 in T.grid(512, 512, 1): + with T.block("C_reindex"): + v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(C_reindex[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_reindex[v0, v1] + + +@T.prim_func +def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] + A[vi, vj] + + +def test_reindex_read_basic(): + sch = tir.Schedule(transpose_elementwise) + block = sch.get_block("B") + sch.reindex(block, 0, "read") + tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) + + +def test_conv2d_reindex_read(): + sch = tir.Schedule(conv2d_nhwc) + block = sch.get_block("conv2d_nhwc") + sch.reindex(block, 1, "read") + tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) + + +def test_matmul_reindex_write(): + sch = tir.Schedule(matmul) + block = sch.get_block("matmul") + sch.reindex(block, 0, "write") + tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=matmul) + + +def 
test_reindex_fail_multiple_read(): + sch = tir.Schedule(multiple_read) + block = sch.get_block("B") + with pytest.raises(ScheduleError): + sch.reindex(block, 0, "read") + + +if __name__ == "__main__": + tvm.testing.main() From aff1312e365142bcb77d6ae847753702a4e3a0c6 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 2 Jun 2022 14:37:11 -0700 Subject: [PATCH 028/181] [PROFILER] Fix percent compute bound calculation (#11542) * [PROFILER] Fix percent compute bound calculation Somehow the runtime was dropped from the percent compute bound calculation. Tolerances on the test we bumped a little bit higher to try and catch mistakes like this in the future. * forgot print --- python/tvm/utils/roofline.py | 2 +- tests/python/unittest/test_runtime_profiling.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/utils/roofline.py b/python/tvm/utils/roofline.py index 6d1ac753e27e5..8a17b9f003123 100644 --- a/python/tvm/utils/roofline.py +++ b/python/tvm/utils/roofline.py @@ -392,7 +392,7 @@ def roofline_from_existing( compute_bound = arith_inten > ridge_point call["Bound"] = "compute" if compute_bound else "memory" per_mem_bound = (loaded_bytes / runtime) / peak_bandwidth * 100 - per_compute_bound = flops / peak_flops * 100.0 + per_compute_bound = (flops / runtime) / peak_flops * 100.0 # We use ratio here because the percentages should be averaged instead of summed. call["Percent of Theoretical Optimal"] = profiling.Ratio( per_compute_bound if compute_bound else per_mem_bound diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 919057f08d27c..29a8414337756 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -328,7 +328,7 @@ def test_roofline_analysis(target, dev): # Ideally we'd like a little tighter bound here, but it is hard to # know how well this dense will perform without tuning. And we # don't have an operator that uses a specific number of flops. - assert call["Percent of Theoretical Optimal"].ratio >= 0 + assert call["Percent of Theoretical Optimal"].ratio >= 5.0 @tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386") @@ -354,7 +354,7 @@ def test_roofline_analysis_rpc(): # Ideally we'd like a little tighter bound here, but it is hard to # know how well this dense will perform without tuning. And we # don't have an operator that uses a specific number of flops. - assert call["Percent of Theoretical Optimal"].ratio >= 0 + assert call["Percent of Theoretical Optimal"].ratio >= 5.0 if __name__ == "__main__": From 017d410bd18fd3e272ea49ea9e11955c3128bb72 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 2 Jun 2022 14:37:40 -0700 Subject: [PATCH 029/181] Fix docker/lint.sh after #10933. 
(#11541) --- docker/lint.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/lint.sh b/docker/lint.sh index a968bc1e6421b..4f7bca445a9fd 100755 --- a/docker/lint.sh +++ b/docker/lint.sh @@ -20,7 +20,7 @@ source "$(dirname $0)/dev_common.sh" SCRIPT_NAME="$0" -DEFAULT_STEPS=( file_type asf cpplint clang_format pylint python_format jnilint cppdocs mypy ) +DEFAULT_STEPS=( file_type asf clang_format cpplint python_format pylint jnilint cppdocs mypy ) inplace_fix=0 @@ -43,12 +43,12 @@ function run_lint_step() { ;; clang_format) if [ $inplace_fix -eq 0 ]; then - cmd=( tests/lint/clang_format.sh ) + cmd=( tests/lint/git-clang-format.sh ) else # NOTE: need to run git status to update some docker-side cache. Otherwise, # git-clang-format will fail with "The following files would be modified but have # unstaged changes:" - cmd=( bash -c 'git status &>/dev/null && tests/lint/git-clang-format.sh -i origin/main' ) + cmd=( bash -c 'git status &>/dev/null && tests/lint/git-clang-format.sh -i --rev origin/main' ) fi ;; cpplint) @@ -62,9 +62,9 @@ function run_lint_step() { ;; python_format) if [ $inplace_fix -eq 0 ]; then - cmd=( tests/lint/python_format.sh ) + cmd=( tests/lint/git-black.sh ) else - cmd=( tests/lint/git-black.sh -i origin/main ) + cmd=( tests/lint/git-black.sh -i --rev origin/main ) fi ;; jnilint) From f31477f9c3c5ad618750ad6d43b6d6020f6b44d6 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 2 Jun 2022 16:47:20 -0700 Subject: [PATCH 030/181] [FIX] Pad feature vectors to the same size in xgboost cost model (#11479) * [FIX] Pad feature vectors to the same size in xgboost cost model * add test * more test * explaination * formatting --- .../tvm/autotvm/tuner/xgboost_cost_model.py | 24 +++++++++++++------ python/tvm/testing/autotvm.py | 11 ++++++--- .../unittest/test_autotvm_xgboost_model.py | 4 ++++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 637891854aee0..d4942ce6a4ca0 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -243,18 +243,27 @@ def fit_log(self, records, plan_size, min_seed_records=500): else: raise RuntimeError("Invalid feature type: " + self.fea_type) result = pool.map_with_error_catching(feature_extract_func, data) + result = list(result) # store results so we can iterate through them twice - # filter out feature with different shapes - fea_len = len(self._get_feature([0])[0]) + # get maximum feature length + fea_len = -1 + for res in result: + if res.status != StatusKind.COMPLETE: + continue + x, _ = res.value + fea_len = max(fea_len, x.shape[0]) xs, ys = [], [] for res in result: if res.status != StatusKind.COMPLETE: continue x, y = res.value - if len(x) == fea_len: + # Features may not be the same size, pad them until they are + if fea_len > len(x): + xs.append(np.pad(x, (0, fea_len - len(x)))) + else: xs.append(x) - ys.append(y) + ys.append(y) if len(xs) < min_seed_records: # no enough samples return False @@ -329,15 +338,16 @@ def _get_feature(self, indexes): for i, fea in zip(need_extract, feas): fea_cache[i] = fea.value if fea.status == StatusKind.COMPLETE else None - feature_len = None + feature_len = -1 for idx in indexes: if fea_cache[idx] is not None: - feature_len = fea_cache[idx].shape[-1] - break + feature_len = max(fea_cache[idx].shape[-1], feature_len) ret = np.empty((len(indexes), feature_len), dtype=np.float32) for i, ii in 
enumerate(indexes): t = fea_cache[ii] + if t.shape[0] < feature_len: + t = np.pad(t, (0, feature_len - t.shape[0])) ret[i, :] = t if t is not None else 0 return ret diff --git a/python/tvm/testing/autotvm.py b/python/tvm/testing/autotvm.py index 6f7bb13fe6dca..b1132cd1faa7f 100644 --- a/python/tvm/testing/autotvm.py +++ b/python/tvm/testing/autotvm.py @@ -62,9 +62,14 @@ def matmul(N, L, M, dtype): # schedule according to config yo, yi = cfg["tile_y"].apply(s, C, y) - xo, xi = cfg["tile_x"].apply(s, C, x) - - s[C].reorder(yo, xo, k, yi, xi) + # Make sure configurations have a varied number of itervars. Splitting adds + # new itervars, so conditionally splitting with cause the number of + # itervars to depend on the tile size. + if cfg["tile_x"].size[-1] > 1: + xo, xi = cfg["tile_x"].apply(s, C, x) + s[C].reorder(yo, xo, k, yi, xi) + else: + s[C].reorder(yo, k, yi, x) return s, [A, B, C] diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py index baecdaceab6d3..7fa3daede07e1 100644 --- a/tests/python/unittest/test_autotvm_xgboost_model.py +++ b/tests/python/unittest/test_autotvm_xgboost_model.py @@ -43,6 +43,10 @@ def test_fit(): upper_model.fit(xs, ys, plan_size=32) + # feature lengths are not guaranteed to always be the same + upper_model.predict(np.ones(12)) + upper_model.predict(np.ones(8)) + def fit_spawn(): assert multiprocessing.get_start_method(False) == "spawn" From 274d8fa964489e03ad97e684902063d935bf192b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 2 Jun 2022 16:56:50 -0700 Subject: [PATCH 031/181] Unbreak CI image build (tensorflow 2.6.5, ci_gpu bugfix) (#11546) * Pin protobuf to 3.20.1 due to #11545. * Unpin and instead update to 2.6.5 * attempt to fix gpu build * Revert to 2.6.3, pin protobuf for ci-arm. 
* escape bash char --- docker/Dockerfile.ci_gpu | 2 +- docker/install/ubuntu_install_tensorflow.sh | 2 +- docker/install/ubuntu_install_tensorflow_aarch64.sh | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 73d13007f1d06..e0d1997de729b 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -24,7 +24,7 @@ FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04 RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub # Base scripts -RUN rm /etc/apt/sources.list.d/nvidia-ml.list && apt-get clean +RUN rm -f /etc/apt/sources.list.d/nvidia-ml.list && apt-get clean RUN apt-get update --fix-missing COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index eaf89ffcf8fef..17d2b31d9bc24 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -23,4 +23,4 @@ set -o pipefail pip3 install \ "h5py==3.1.0" \ keras==2.6 \ - tensorflow==2.6.2 + tensorflow==2.6.5 diff --git a/docker/install/ubuntu_install_tensorflow_aarch64.sh b/docker/install/ubuntu_install_tensorflow_aarch64.sh index 6acf8b7270d81..8d5b6765deb05 100755 --- a/docker/install/ubuntu_install_tensorflow_aarch64.sh +++ b/docker/install/ubuntu_install_tensorflow_aarch64.sh @@ -26,5 +26,6 @@ apt-get install -y --no-install-recommends libhdf5-dev pip3 install \ "h5py==3.1.0" \ keras==2.6 \ - tensorflow-aarch64==2.6.2 \ + tensorflow-aarch64==2.6.3 \ + "protobuf<4" \ -f https://snapshots.linaro.org/ldcg/python-cache/tensorflow-aarch64/ From 2ae20882d3e34cc6e5acef992c23c17a585c25aa Mon Sep 17 00:00:00 2001 From: Christian Convey Date: Fri, 3 Jun 2022 11:58:30 -0400 Subject: [PATCH 032/181] [hexagon][testing] add TIRScript elemwise-add (#11490) Replace TE-based elementwise-add benchmark with a TVMScript-based one. Update Hexagon target architecture from v68 to v69. As a result, the benchmark now requires a version of Hexagon SDK newer than 4.4.0.1. Version 4.5.0.3 is known to work. --- .../test_hexagon/benchmark_elemwise_add.py | 434 ++++++++++++++++++ .../contrib/test_hexagon/benchmark_hexagon.py | 245 ---------- .../contrib/test_hexagon/benchmark_util.py | 34 ++ 3 files changed, 468 insertions(+), 245 deletions(-) create mode 100644 tests/python/contrib/test_hexagon/benchmark_elemwise_add.py delete mode 100644 tests/python/contrib/test_hexagon/benchmark_hexagon.py diff --git a/tests/python/contrib/test_hexagon/benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/benchmark_elemwise_add.py new file mode 100644 index 0000000000000..70266d7939bc5 --- /dev/null +++ b/tests/python/contrib/test_hexagon/benchmark_elemwise_add.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import os.path +import sys +import pytest +import numpy as np +import logging +import tempfile + +import tvm.testing +import tvm.script +from tvm.script import tir as T +from tvm import te +from tvm.contrib.hexagon.build import HexagonLauncherRPC +from . import benchmark_util + +# This is a fixed detail of the v68 architecture. +HVX_VECTOR_BYTES = 128 + +_HEXAGON_TARGET = tvm.target.hexagon("v69", link_params=True) + +_SUPER_TARGET = tvm.target.Target(_HEXAGON_TARGET, host=_HEXAGON_TARGET) + +# NOTE on server ports: +# These tests use different port numbers for the RPC server (7070 + ...). +# The reason is that an RPC session cannot be gracefully closed without +# triggering TIME_WAIT state on the server socket. This prevents another +# server to bind to the same port until the wait time elapses. + +_BT = benchmark_util.BenchmarksTable() + +_CSV_COLUMN_ORDER = [ + # Identifies which TE-compute / TIRScript is used as the basis for the + # benchmarked primfunc. Only needs to be meaningful to humans. + "basic_kernel", + # The tensors' element type + "dtype", + # When applicable, indicates the particular variation of schedules + # apply by the Python code. Decoding this may require looking at this + # script's source code. + "sched_type", + # The memory location of the tensors used during the execution of + # the primfunc. We currently assume just one location. + # This will likely need to be generalized as we add more sophisticated + # primfuncs. + "mem_scope", + # For primfuncs that treat tensor buffers as collections of 1D vectors, + # this is the number of vectors in each tensor. + # This will likely need to be generalized as we add more sophisticated + # primfuncs. + "num_vectors_per_tensor", + # Reserved columns defined by the BenchmarksTable class. + "row_status", + "timings_min_usecs", + "timings_max_usecs", + "timings_median_usecs", + "timings_mean_usecs", + "timings_stddev_usecs", + # For benchmarks that produce files on the host file system, this indicates + # their location. Useful for post-mortem investigation of benchmark results. + "host_files_dir_path", + # Miscellaneous comments about the benchmark. + "comments", +] + +_HOST_OUTPUT_DIR = tempfile.mkdtemp() + +_PRIMFUNC_NAME = "elemwise_add" + +print("-" * 80) +print("OUTPUT DIRECTORY: {}".format(_HOST_OUTPUT_DIR)) +print("-" * 80) +print() + + +class UnsupportedException(Exception): + """ + Indicates that the specified benchmarking configuration is known to + currently be unsupported. The Exception message may provide more detail. + """ + + +class NumericalAccuracyException(Exception): + """ + Indicates that the benchmarking configuration appeared to run successfully, + but the output data didn't have the expected accuracy. + """ + + +from typing import Tuple + + +def _get_irmod_elemwise_add( + _PRIMFUNC_NAME: str, shape: list, dtype: str, mem_scope: str +) -> tvm.ir.module.IRModule: + """ + Return an IRModule containing a single primfunc, expressed as NS-TIR. + + The primfunc implements elementwise-add. Its signature is (A,B,C), where + A and B are the input tensors, and C is the output tensor. + All three tensors have the specfied shape, dtype, and mem_scope. + + If the specified primfunc is known to be unsupported, raise an UnsupportedExcetion. + """ + assert len(shape) == 2 + + # TVMScript can reference simple Python variables, but it doesn't + # curently support more complex Python expressions... 
+ ( + dim0_size, + dim1_size, + ) = shape + dtype_str = str(dtype) + + if mem_scope == "global.vtcm": + raise UnsupportedException("This benchmark kernel does not yet support VTCM buffers.") + + # This check is currently elided by the one above, but it should become relevant as soon + # as we add VTCM support to this kernel generator. + # + # Also: The VTCM budget is a very rough estimate, based only on experience. + # Assuming that it's even reasonable to use a hard-coded estimate AT ALL, this number + # may need tweaking. + estimated_vtcm_budget_bytes = HVX_VECTOR_BYTES * 1024 + + dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits + assert dtype_bits % 8 == 0 + dtype_bytes = dtype_bits // 8 + + num_vtcm_tensors = 3 + estimated_vtcm_needed_bytes = shape[0] * shape[1] * dtype_bytes * num_vtcm_tensors + + if estimated_vtcm_needed_bytes > estimated_vtcm_budget_bytes: + raise UnsupportedException("Expect to exceed VTCM budget.") + + @tvm.script.ir_module + class BenchmarkModule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + + A = T.match_buffer(a, shape, dtype=dtype) + B = T.match_buffer(b, shape, dtype=dtype) + C = T.match_buffer(c, shape, dtype=dtype) + + for i in range(dim0_size): + for j in range(dim1_size): + C[i, j] = A[i, j] + B[i, j] + + return BenchmarkModule + + +def _benchmark_hexagon_elementwise_add_kernel( + hexagon_launcher: HexagonLauncherRPC, shape: list, dtype: str, mem_scope: str +): + """ + Generate and benchmark a single elementwise-add kernel for Hexagon. + + Produce these outputs: + - Printed status updates / results to stdout and/or stderr. + + - Create a new subdirectory under _HOST_OUTPUT_DIR, and populate it with + various logs and intermediate files. + + - Add to _BT a row describing this benchmark run. + """ + # Represent the benchmark details in a form required by the benchmark table + # and for other logging... + keys_dict = { + "basic_kernel": "ewise-add", + "dtype": dtype, + "shape": shape, + "mem_scope": mem_scope, + } + + desc = benchmark_util.get_benchmark_decription(keys_dict) + + # Create the host-side directory for this benchmark run's files / logs... + host_files_dir_name = benchmark_util.get_benchmark_id(keys_dict) + host_files_dir_path = os.path.join(_HOST_OUTPUT_DIR, host_files_dir_name) + os.mkdir(host_files_dir_path) + + keys_dict["host_files_dir_path"] = host_files_dir_path + + log_file_path = os.path.join(host_files_dir_path, "out.txt") + with open(log_file_path, "w") as log_file: + print(f"CONFIGURATION: {desc}") + log_file.write(f"CONFIGURATION: {desc}\n") + + try: + ns_tir_module = _get_irmod_elemwise_add(_PRIMFUNC_NAME, shape, dtype, mem_scope) + + # Dump the primfunc NS-TIR (as text) to the log file... + lowered_mod = tvm.lower(ns_tir_module, _PRIMFUNC_NAME) + log_file.write("LOWERED IR MODULE:\n") + log_file.write(str(lowered_mod)) + log_file.write("\n") + + # Lower the primfunc's IRModule to Hexagon object code... + A = tvm.te.placeholder(shape, dtype=dtype) + B = tvm.te.placeholder(shape, dtype=dtype) + C = tvm.te.placeholder(shape, dtype=dtype) + + built_module: tvm.driver.build_module.OperatorModule = tvm.build( + ns_tir_module, + [ + A, + B, + C, + ], + _SUPER_TARGET, + name=_PRIMFUNC_NAME, + ) + + # Create an actual Hexagon-native shared object file, initially stored on the + # host's file system... 
+ host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") + built_module.save(host_dso_binary_path) + print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") + + # Upload the .so to the Android device's file system (or wherever is appropriate + # when using the Hexagon simulator)... + target_dso_binary_filename = "test_binary.so" + hexagon_launcher.upload(host_dso_binary_path, target_dso_binary_filename) + + # Generate our testing / validation data... + ( + host_numpy_A_data, + host_numpy_B_data, + host_numpy_C_data_expected, + ) = _get_elemwise_add_reference_value_tensors(shape, dtype) + + with hexagon_launcher.start_session() as sess: + # On the target device / simulator, make our Hexagon-native shared object + # available for use... + loaded_hexagon_module: tvm.runtime.module.Module = hexagon_launcher.load_module( + target_dso_binary_filename, sess + ) + + # Create the target-side tensors to hold the primfunc's inputs and outputs... + A_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) + B_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) + C_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) + + # Populate the primfunc's input tensors... + A_data.copyfrom(host_numpy_A_data) + B_data.copyfrom(host_numpy_B_data) + + # Actually benchmark the primfunc... + timer = loaded_hexagon_module.time_evaluator( + "main", sess.device, number=10, repeat=1 + ) + timing_result = timer(A_data, B_data, C_data) + + print(f"TIMING RESULT: {timing_result}") + log_file.write(f"TIMING RESULT: {timing_result}\n") + + # Verify that the computation actually happened, and produced the correct result. + result = C_data.numpy() + + if dtype == "float16": + # These are the closest tolerance we currently expect / require for these + # kernels. They may be changed in the future. + rel_tolerance = 0.005 + abs_tolerance = 2.0 + elif dtype == "int8": + rel_tolerance = 0 + abs_tolerance = 0 + else: + raise Exception(f"Unexpected dtype: {dtype}") + + # TODO: We're assuming that *any* assertion thrown by 'assert_allclose' is because + # the numerical differences were too large. But ideally this code would + # differentiate between (a) numerical difference errors, which should simply be + # recorded as a failed benchmark run, vs. (b) more serious errors that should + # kill the overall script. + try: + tvm.testing.assert_allclose( + result, host_numpy_C_data_expected, rel_tolerance, abs_tolerance + ) + except AssertionError as e: + raise NumericalAccuracyException(str(e)) + + _BT.record_success(timing_result, **keys_dict) + + except NumericalAccuracyException as e: + print() + print(f"FAIL: Numerical accuracy error. See log file.") + + log_file.write("\n") + log_file.write(f"FAIL: {e}\n") + + _BT.record_fail(**keys_dict, comments=f"Numerical accuracy error. See log file.") + + except UnsupportedException as e: + print() + print(f"SKIP: {e}") + + log_file.write("\n") + log_file.write(f"SKIP: {e}\n") + + _BT.record_skip(**keys_dict, comments=f"Unsupported configuration: {e}") + + +def _get_elemwise_add_reference_value_tensors(shape: list, dtype: str): + """ + Return [A:np.array, B:np.array, C:np.array] + + `A`, `B`, and `C` are reference data used to exercise and validate + an elementwise-add kernel: C = A+B. + + NOTE: These data are primarily meant for performance testing. + The values may be helpful in detecting correctness issues, but that's + a secondary consideration here. 
+ """ + assert len(shape) == 2 + + A = np.ndarray(shape, dtype=dtype) + B = np.ndarray(shape, dtype=dtype) + + np_dtype = A.dtype + + if np_dtype.kind in ["i", "u"]: + # We allow overflow for integer types because it tends to be well-behaved + # and well-understood... + min_value = np.iinfo(np_dtype).min + max_value = np.iinfo(np_dtype).max + + next_value = min_value + + for i in range(shape[0]): + for j in range(shape[1]): + A[i, j] = next_value + B[i, j] = next_value * 2 + next_value += 1 + + elif np_dtype.kind == "f": + # NOTE: For simplicity, we avoid test data that that require + # well-defined behavior on floating-point overflow. + # But it may be reasonable to test that in the future. + min_value = np.finfo(np_dtype).min + max_value = np.finfo(np_dtype).max + + min_input_value = min_value / 2.0 + 1 + max_input_value = max_value / 2.0 - 2 + delta = (max_input_value - min_input_value) / (shape[0] * shape[1]) + + next_value = min_input_value + + for i in range(shape[0]): + for j in range(shape[1]): + A[i, j] = next_value + B[i, j] = next_value + 1 + next_value += delta + + else: + assert False, f"Unexpected data type: {np_dtype}" + + C = A + B + return [ + A, + B, + C, + ] + + +@tvm.testing.requires_hexagon +def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC): + for dtype in [ + "int8", + "float16", + ]: + + for mem_scope in [ + "global", + "global.vtcm", + ]: + + # These numbers are fairly arbitrary, but they're meant to stress memory/caches to + # various extents. + for num_vectors_per_tensor in [ + 1, + 16, + 64, + 512, + 2048, + ]: + + dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits + assert dtype_bits % 8 == 0 + dtype_bytes = dtype_bits // 8 + + elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes + + shape = [ + num_vectors_per_tensor, + elem_per_hvx_vector, + ] + + print() + _benchmark_hexagon_elementwise_add_kernel(hexagon_launcher, shape, dtype, mem_scope) + + print("-" * 80) + print(f"OUTPUT DIRECTORY: {_HOST_OUTPUT_DIR}") + print("-" * 80) + print() + + tabular_output_filename = os.path.join(_HOST_OUTPUT_DIR, "benchmark-results.csv") + with open(tabular_output_filename, "w") as csv_file: + _BT.print_csv(csv_file, _CSV_COLUMN_ORDER) + + print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}") + + _BT.print_csv(sys.stdout, _CSV_COLUMN_ORDER) + + if _BT.has_fail() > 0: + pytest.fail("At least one benchmark configuration failed", pytrace=False) diff --git a/tests/python/contrib/test_hexagon/benchmark_hexagon.py b/tests/python/contrib/test_hexagon/benchmark_hexagon.py deleted file mode 100644 index 2a1d6796e7315..0000000000000 --- a/tests/python/contrib/test_hexagon/benchmark_hexagon.py +++ /dev/null @@ -1,245 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -import os.path -import sys -import pytest -import numpy as np -import logging -import tempfile - -import tvm.testing -from tvm import te -from tvm.contrib.hexagon.build import HexagonLauncherRPC -from .benchmark_util import BenchmarksTable - -RPC_SERVER_PORT = 7070 - -# This is a fixed detail of the v68 architecture. -HVX_VECTOR_BYTES = 128 - -# NOTE on server ports: -# These tests use different port numbers for the RPC server (7070 + ...). -# The reason is that an RPC session cannot be gracefully closed without -# triggering TIME_WAIT state on the server socket. This prevents another -# server to bind to the same port until the wait time elapses. - - -@tvm.testing.requires_hexagon -def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC): - """ - Starting with an elementwise-add computation, try various schedules / optimizations to - see the impact they have on performance. - - The main motivation for this test is to explore the relationship between these - schedules / optimizations vs. how effectively the primfunc uses the Hexagon's - HVX units. - """ - host_output_dir = tempfile.mkdtemp() - - print("-" * 80) - print("OUTPUT DIRECTORY: {}".format(host_output_dir)) - print("-" * 80) - print() - - bt = BenchmarksTable() - - # Create and benchmark a single primfunc. - # If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'bt'. - def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor): - version_name = f"dtype:{dtype}-schedtype:{sched_type}-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}" - print() - print(f"CONFIGURATION: {version_name}") - - if num_vectors_per_tensor == 2048 and mem_scope == "global.vtcm": - bt.record_skip( - dtype=dtype, - sched_type=sched_type, - mem_scope=mem_scope, - num_vectors_per_tensor=num_vectors_per_tensor, - comments="Expect to exceed VTCM budget.", - ) - return - - dtype_bits = tvm._ffi.runtime_ctypes.DataType(dtype).bits - assert dtype_bits % 8 == 0 - dtype_bytes = dtype_bits // 8 - - elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes - - # Note! We're providing the complete input tensor shapes now, - # whereas the original code only reveals the exact shape when - # about to call the kernel. - - shape = [ - num_vectors_per_tensor, - elem_per_hvx_vector, - ] - - A = tvm.te.placeholder(shape, dtype=dtype) - B = tvm.te.placeholder(shape, dtype=dtype) - C = tvm.te.compute(A.shape, lambda i, j: A[i, j] + B[i, j], name="C") - - sched = tvm.te.create_schedule(C.op) - - if sched_type == 1: - pass - elif sched_type == 2: - sched[C].vectorize(C.op.axis[1]) - else: - raise Exception("Unknown schedule type") - - # If we're using VTCM, we *must* add a transform_layout step to the schedule. - # Otherwise the generated code will crash. - # As of 2022-04-12 the crash does not provide a useful error message to the - # host Python code. - if mem_scope == "global.vtcm": - for tensor in [A, B, C]: - sched[tensor].transform_layout(lambda i, j: [i, te.AXIS_SEPARATOR, j]) - - # This module is only created so humans can inspect its IR. 
- module_for_ir_dump = tvm.lower(sched, [A, B, C], "foo") - - report_path = os.path.join(host_output_dir, f"{version_name}.txt") - - with open(report_path, "w") as f: - f.write("LOWERED IR MODULE:\n") - f.write(str(module_for_ir_dump)) - f.write("\n") - - target_hexagon = tvm.target.hexagon("v68", link_params=True) - func = tvm.build( - sched, - [A, B, C], - tvm.target.Target(target_hexagon, host=target_hexagon), - name="elemwise_add", - ) - - host_dso_binary_path = os.path.join(host_output_dir, f"test_binary-{version_name}.so") - target_dso_binary_filename = "test_binary.so" - - func.save(str(host_dso_binary_path)) - print("SAVED BINARY TO HOST PATH: {}".format(str(host_dso_binary_path))) - - hexagon_launcher.upload(host_dso_binary_path, target_dso_binary_filename) - - try: - with hexagon_launcher.start_session() as sess: - mod = hexagon_launcher.load_module(target_dso_binary_filename, sess) - - host_numpy_A_data = np.ndarray(shape, dtype=dtype) - host_numpy_B_data = np.ndarray(shape, dtype=dtype) - - for i in range(shape[0]): - for j in range(shape[1]): - host_numpy_A_data[i, j] = i + j - host_numpy_B_data[i, j] = (i + 1) * (j + 1) - - host_numpy_C_data_expected = host_numpy_A_data + host_numpy_B_data - - A_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) - A_data.copyfrom(host_numpy_A_data) - - B_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) - B_data.copyfrom(host_numpy_B_data) - - C_data = tvm.nd.empty(shape, dtype, sess.device, mem_scope) - - # NOTE: We may want to soften these numbers, depending on future findings. - timer = mod.time_evaluator("elemwise_add", sess.device, number=10, repeat=1) - timing_result = timer(A_data, B_data, C_data) - - # Verify that the computation actually happened, and produced the correct result. - result = C_data.numpy() - tvm.testing.assert_allclose(host_numpy_C_data_expected, result) - - bt.record_success( - timing_result, - dtype=dtype, - sched_type=sched_type, - mem_scope=mem_scope, - num_vectors_per_tensor=num_vectors_per_tensor, - ) - - except Exception as err: - f.write("ERROR:\n") - f.write("{}\n".format(err)) - bt.record_fail( - dtype=dtype, - sched_type=sched_type, - mem_scope=mem_scope, - num_vectors_per_tensor=num_vectors_per_tensor, - comments=f"See {report_path}", - ) - - # ----------------------------------------------------------------------------------------------- - - csv_column_order = [ - "dtype", - "sched_type", - "mem_scope", - "num_vectors_per_tensor", - "row_status", - "timings_min_usecs", - "timings_max_usecs", - "timings_median_usecs", - "timings_mean_usecs", - "timings_stddev_usecs", - "comments", - ] - - # Hexagon v69 allows more dtypes, but we're sticking with v68 for now. - for dtype in [ - "int8", - ]: - - # These numbers are only meaningful in the context of this script. - for sched_type in [ - 1, - 2, - ]: - - for mem_scope in ["global", "global.vtcm"]: - - # These numbers are fairly arbitrary, but they're meant to stress memory/caches to - # various extents. - for num_vectors_per_tensor in [ - 1, - 16, - 64, - 512, - 2048, - ]: - - test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor) - - # Report our progress. 
- bt.print_csv(sys.stdout, csv_column_order) - - print("-" * 80) - print(f"OUTPUT DIRECTORY: {host_output_dir}") - print("-" * 80) - print() - - tabular_output_filename = os.path.join(host_output_dir, "benchmark-results.csv") - with open(tabular_output_filename, "w") as csv_file: - bt.print_csv(csv_file, csv_column_order) - print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}") - - if bt.has_fail() > 0: - pytest.fail("At least one benchmark configuration failed", pytrace=False) diff --git a/tests/python/contrib/test_hexagon/benchmark_util.py b/tests/python/contrib/test_hexagon/benchmark_util.py index 5a75e9a6e80fb..113c7780c130f 100644 --- a/tests/python/contrib/test_hexagon/benchmark_util.py +++ b/tests/python/contrib/test_hexagon/benchmark_util.py @@ -139,3 +139,37 @@ def print_csv(self, f, column_name_order, timing_decimal_places=3): csv_line_dict[col_name] = str_value writer.writerow(csv_line_dict) + + +def get_benchmark_id(keys_dict): + """ + Given a dictionary with the distinguishing characteristics of a particular benchmark + line item, compute a string that uniquely identifies the benchmark. + + The returned string: + - is a valid directory name on the host's file systems, and + - should be easy for humans to parse + + Note that the insertion order for `keys_dict` affects the computed name. + """ + # Creat a copy, because we might be modifying it. + d = dict(keys_dict) + + # Sniff for shape-like lists, because we want them in a form that's both + # readable and filesystem-friendly... + for k, v in d.items(): + if isinstance(v, list) or isinstance(v, tuple): + v2 = "_".join([str(x) for x in v]) + d[k] = v2 + + return "-".join([f"{k}:{v}" for k, v in d.items()]) + + +def get_benchmark_decription(keys_dict): + """ + Similar to `get_benchmark_id`, but the focus is on human-readability. + + The returned string contains no line-breaks, but may contain spaces and + other characters that make it unsuitable for use as a filename. + """ + return " ".join([f"{k}={v}" for k, v in keys_dict.items()]) From b086005f8f9d439ff8397dcc6b048fd8dda5a995 Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Fri, 3 Jun 2022 10:58:28 -0700 Subject: [PATCH 033/181] [ci] Fix action expressions for tvm-bot workflow (#11556) These weren't caught by `actionlint` for some reason but GitHub doesn't merge multiple `if`s, so this combines them into one. 
Co-authored-by: driazati --- .github/workflows/tvmbot.yml | 3 +-- tests/scripts/git_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tvmbot.yml b/.github/workflows/tvmbot.yml index c9d2cf71e6a70..784f6899a3be3 100644 --- a/.github/workflows/tvmbot.yml +++ b/.github/workflows/tvmbot.yml @@ -13,9 +13,8 @@ concurrency: jobs: run-tvm-bot: - if: github.repository == 'apache/tvm' + if: ${{ github.event.issue.pull_request && github.repository == 'apache/tvm' }} runs-on: ubuntu-20.04 - if: ${{ github.event.issue.pull_request }} steps: - uses: actions/checkout@v2 - name: Run tvm-bot diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py index 7cd1b6b2fe596..0e2e85e552431 100644 --- a/tests/scripts/git_utils.py +++ b/tests/scripts/git_utils.py @@ -33,15 +33,15 @@ def compress_query(query: str) -> str: def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = None): print(f"Requesting POST to", url, "with", body) headers = {} + req = request.Request(url, headers=headers, method="POST") if auth is not None: - auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}") - request.add_header("Authorization", f"Basic {auth_str}") + auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}".encode()) + req.add_header("Authorization", f"Basic {auth_str}") if body is None: body = "" req.add_header("Content-Type", "application/json; charset=utf-8") - req = request.Request(url, headers=headers, method="POST") data = json.dumps(body) data = data.encode("utf-8") req.add_header("Content-Length", len(data)) From 9dceb4e191c5588046c1478243d031f0b6052311 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Fri, 3 Jun 2022 13:19:53 -0700 Subject: [PATCH 034/181] [BYOC] Two helper passes for external codegen using RelayToTIR custom pass machinery (#11474) * [BYOC] Two helper passes for external codegen using RelayToTIR custom pass machinery (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). For reasons explained in the above thread I'm moving CUTLASS to be IRModule-at-a-time external codegen using a custom RelayToTIR pass instead of the traditional function-at-a-time external codegen using a relay.ext.cutlass registered function. This means some of the rewriing done on-the-fly by LowerTEPass now needs to be done by the custom pass directly. This PR supplies two passes which ease that burden: - Before starting the CUTLASS-specific processing, make sure all "Compiler" attributed functions have unique global definitions (ie are outlined). Though functions start in this form after BYOC partitioning, under Graph and AOT compilation flows those functions are then inlined to pass through the 'codegen' keyhole which assumes the whole model is just one self-contained main function. This pass will undo that. (I gave up trying to just remove the inlining in the first place.) - After the CUTLASS-specific processing the now compiled "Compiler" attributed functions need to marked as 'extern'. The te_compiler.cc uses the "ExternalSymbol" attribute for that, but since a) the symbol name is never needed, on the presense of the attribute is significant downstream and b) "ExternalSymbol" is easy to confuse with "global_symbol", I just replaced "ExternalSymbol" with "Extern" with an Integer(1) (cf "Primitive"). 
The outlining pass is a little more general than necessary because it (will also) be used by Collage to rewrite the IRModule into optimally partitioned form while making maximal reuse of partition functions. Hence the abstract GlobalSymbolCache. * - Andrew's comments --- include/tvm/ir/expr.h | 3 +- include/tvm/relay/attrs/call.h | 2 +- include/tvm/relay/function.h | 32 ++- python/tvm/relay/transform/transform.py | 70 ++++-- src/ir/expr.cc | 3 +- src/parser/tokenizer.h | 4 +- src/relay/backend/te_compiler.cc | 8 +- src/relay/backend/vm/compiler.cc | 4 +- src/relay/ir/function.cc | 2 +- src/relay/op/nn/nn.cc | 1 + .../transforms/compiler_function_utils.cc | 212 ++++++++++++++++++ .../transforms/compiler_function_utils.h | 135 +++++++++++ src/relay/transforms/dead_code.cc | 6 +- src/relay/transforms/inline.cc | 5 +- .../transform/test_compiler_function_utils.py | 162 +++++++++++++ 15 files changed, 608 insertions(+), 41 deletions(-) create mode 100644 src/relay/transforms/compiler_function_utils.cc create mode 100644 src/relay/transforms/compiler_function_utils.h create mode 100644 tests/python/relay/transform/test_compiler_function_utils.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4a00de802c61e..b54a067e1c941 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; // PrimExprs that are useful as runtime containers. diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h index 167a593ff377b..e0b347de17837 100644 --- a/include/tvm/relay/attrs/call.h +++ b/include/tvm/relay/attrs/call.h @@ -35,7 +35,7 @@ namespace relay { * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. */ struct CallLoweredAttrs : public tvm::AttrsNode { - /*! \brief The metadata attached to the call node. */ + /*! \brief Additional metadata attached to the call node. Should be replaced by explict fields. */ Map metadata; TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 5869f878aa856..052d04fe24119 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -170,19 +170,40 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func); * \brief namespace of the attributes that can be attached to a relay::Function. */ namespace attr { -/*! \brief Mark the function as a primitive function. */ + +/*! + * \brief Mark the function as representing a sub-graph which is to be lowered or compiled as + * a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc. + * If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below. + * The function body should be considered opaque by Relay, and many passes simply ignore these + * functions. + * + * Type: Integer + */ constexpr const char* kPrimitive = "Primitive"; + +/*! + * \brief Mark the function as externally implemented, ie bound in a runtime::Module within the + * IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally + * the only attribute when present. 
+ * + * Type: Integer + */ +constexpr const char* kExtern = "Extern"; + /*! - * \brief Indicate the compiler that should be used for building this function. - * When this is unset or set to "default", the default compilation pipeline will be used. + * \brief Indicates the name of the external codegen 'compiler' that should be used to lower + * or compile the function other than TVM's default lowering pipeline. The name may correspond + * to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'. + * + * Type: String */ constexpr const char* kCompiler = "Compiler"; + /*! \brief Indicate if the function is a closure. */ constexpr const char* kClosure = "Closure"; /*! \brief Store a Var to parameter/Constant mapping on a Function. */ constexpr const char* kParams = "__params__"; -/*! \brief Store the unique external symbol for external compilers. */ -constexpr const char* kExternalSymbol = "ExternalSymbol"; /*! \brief Mark if the function should be avoided being optimized. */ constexpr const char* kSkipOptimization = "SkipOptimization"; /*! \brief Treat the function as a composite operator. */ @@ -193,6 +214,7 @@ constexpr const char* kInline = "Inline"; constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; /*! \brief Mark the function as only composed of reshape operations. */ constexpr const char* kReshapeOnly = "relay.reshape_only"; + } // namespace attr } // namespace relay diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 9f253f8e88ba7..694dbb45218ca 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -802,24 +802,6 @@ def Inline(): return _ffi_api.Inline() -def InlineComposites(target): - """Perform inlining on the given Relay IR module. The functions originate - from the MergeComposite pass based on an input pattern table will fold back - to main. Currently, this is used for the TRT BYOC which expects a single - primitive function to operate on. - - Parameters - ---------- - target: str - The byoc target for which ops need to fold back to primitive function. - Returns - ------- - ret: tvm.transform.Pass - The registered pass that performs inlining for a Relay IR module. - """ - return _ffi_api.InlineComposites(target) - - def gradient(expr, mod=None, mode="higher_order"): """ Transform the input function, @@ -1386,3 +1368,55 @@ def SplitArgs(max_function_args): The registered pass for constant folding. """ return _ffi_api.SplitArgs(max_function_args) + + +def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): + """Outlines all literal functions in direct call positions which have a "Compiler" + attribute. + + The outlined functions are bound to unique global vars according to their existing + "global_symbol" attribute. At most one function with the same global symbol is outlined. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. + + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + to prepare the IRModule before custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. 
+ """ + return _ffi_api.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter) + + +def MarkCompilerFunctionsAsExtern(compiler_filter=""): + """Marks all global functions which have a "Compiler" attribute matching + compiler_filter as 'extern'. + + The function's attributes are replaced with a single "Extern" attribute, and + all calls to the function are switched to use the 'call_lowered' calling convention. + + If compiler_filter is non-empty only functions with that as their attribute value are + outlined. + + This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + cleanup the IRModule after custom lowering. + + Parameters + ---------- + compiler_filter : String + If non-empty, the 'compiler' attribute to filter on. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f041..a3318bf94fc66 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint, Type type) { +GlobalVar::GlobalVar(String name_hint, Type type, Span span) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->checked_type_ = std::move(type); + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 4ac1ceef26dce..505784e4bf70e 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -295,8 +295,6 @@ struct Tokenizer { int line = this->line; int column = this->col; - ICHECK_EQ(Peek(), '['); - Next(); std::stringstream type_key; while (More() && Peek() != ']') { type_key << Next(); @@ -498,7 +496,7 @@ struct Tokenizer { auto token = NewToken(TokenType::kQuestion); Next(); return token; - } else if (MatchString("meta")) { + } else if (MatchString("meta[")) { return TokenizeMetaRef(); } else if (next == '#') { return TokenizeAttr(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 73b44f7361a57..c78f3abd6eccf 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -168,7 +168,7 @@ class TECompilerImpl : public TECompilerNode { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - // Unfortuantely, Optional() is indistinguishable from + // Unfortunately, Optional() is indistinguishable from // NullValue(), and DictAttrs() is nullptr, so to erase the attributes, we // need pass in DictAttrs()), which is a DictAttrs containing no // attributes. @@ -176,8 +176,8 @@ class TECompilerImpl : public TECompilerNode { WithFields(GetRef(function_node), function_node->params, function_node->body, function_node->ret_type, function_node->type_params, /* erase attributes */ DictAttrs(Map())); - // Mark function as 'extern' using the "ExternalSymbol" attribute. - function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); + // Mark function as 'extern'. 
+ function = WithAttr(std::move(function), attr::kExtern, Integer(1)); module->Add(kv2.first, function); } } @@ -688,7 +688,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Nothing to lower inside primitive/external functions. return GetRef(function_node); } else { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e0b742a840906..d9730b1b5a4ca 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -922,7 +922,7 @@ void VMCompiler::LowerImpl(IRModule mod) { for (const auto& pair : context_.module->functions) { auto gvar = pair.first; if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kExternalSymbol).defined()) { + if (n->HasNonzeroAttr(attr::kExtern)) { // Already compiled during lowering. continue; } @@ -1131,7 +1131,7 @@ size_t VMCompiler::PopulateGlobalMap() { // Excludes PrimFuncs and externs, which are managed by the primitive_map_. for (const auto& kv : context_.module->functions) { if (const auto* function_node = kv.second.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { context_.global_map.emplace(kv.first, context_.global_map.size()); } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index bf0dd577a4d29..63e74144e0616 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -112,7 +112,7 @@ FuncType FunctionNode::func_type_annotation() const { const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { if (const auto* function_node = base_func.as()) { if (!function_node->GetAttr(attr::kCompiler).defined() && - !function_node->GetAttr(attr::kExternalSymbol).defined() && + !function_node->HasNonzeroAttr(attr::kExtern) && !function_node->HasNonzeroAttr(attr::kSkipOptimization)) { return function_node; } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 234cafdca1502..41b47401de1c2 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1012,6 +1012,7 @@ Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT f - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("tensor_a", "3D Tensor", "The first input.") .add_argument("tensor_b", "3D Tensor", "The second input.") diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc new file mode 100644 index 0000000000000..b98d089b346a3 --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.cc + * \brief Helper passes for working with functions with the "Compiler" attribute. + */ + +#include "./compiler_function_utils.h" + +#include "../op/call/call.h" +#include "tvm/relay/analysis.h" +#include "tvm/relay/expr_functor.h" + +namespace tvm { +namespace relay { +namespace transforms { +namespace { + +/*! + * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given + * module will be extended with the newly outlined functions. + */ +class Outliner : public MixedModeMutator { + public: + Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) + : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* function_node = new_call->op.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + auto function = GetRef(function_node); + DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + << "' attribute should not have free variables"; + // Ask the cache to supply a unique global var for this function. + GlobalVar global_symbol = cache_->GetGlobalSymbol(function); + // Depending on the cache's implementation, two structurally equal (but not object equal) + // functions may be assigned the same global symbol. If so we'll lift it just once, but + // rewrite all the calls. + if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { + function = + WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); + mod_->Add(global_symbol, function); + } + // Update the call. + return WithFields(new_call, global_symbol); + } + } + return post; + } + + private: + /*! + * \brief A cached mapping from functions to global variables. Depending on the implementation + * the cache may generate fresh symbols or require the function to already have a "global_symbol" + * attribute, and may share symbols between structurally equal functions. + */ + GlobalSymbolCache* cache_; + /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ + std::string compiler_filter_; + /*! \brief Module being rewritten. */ + IRModule mod_; +}; + +/*! + * \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention. 
+ */ +class CallRewriter : public MixedModeMutator { + public: + CallRewriter(std::string compiler_filter, IRModule mod) + : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* global_var_node = new_call->op.as()) { + if (const auto* function_node = + mod_->Lookup(GetRef(global_var_node)).as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { + Optional opt_global_symbol = + function_node->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()); + GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value()); + CallLoweredAttrs attrs; + attrs.metadata.Set("relay_attrs", new_call->attrs); + return CallLowered(global_symbol, new_call->args, attrs, new_call->span); + } + } + } + return post; + } + + private: + /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ + std::string compiler_filter_; + /*! \brief Module being rewritten. */ + IRModule mod_; +}; + +} // namespace + +GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "ExistingGlobalSymbolCache requires all functions to already have a '" + << tvm::attr::kGlobalSymbol << "' attribute"; + std::string global_symbol = opt_global_symbol.value(); + auto itr = global_vars_.find(global_symbol); + if (itr != global_vars_.end()) { + return itr->second; + } + // Ok if function does not have a checked_type, but if it does capture it in the global var. + GlobalVar global_var(global_symbol, function->checked_type_, function->span); + global_vars_.emplace(global_symbol, global_var); + return global_var; +} + +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( + IRModule mod, transform::PassContext ctx) { + IRModule output_mod = GetRef(mod.CopyOnWrite()); + for (const auto& kv : mod->functions) { + const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second); + if (function_node) { + Expr new_body = + Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Add(kv.first, new_function); + } + } + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "OutlineCompilerFunctions", {}); +} + +// Any Java programmers in the house? +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) { + return OutlineCompilerFunctions(std::make_shared(), + std::move(compiler_filter)); +} + +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { + runtime::TypedPackedFunc pass_func = + [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { + IRModule output_mod = mod->ShallowCopy(); + + // First pass, rewrite the calls. + // We have to do this before marking functions as 'extern' to know which calls to rewrite! 
+ for (const auto& kv : mod->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + Expr new_body = + CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Update(kv.first, new_function); + } + } + + // Second pass, mark functions as 'extern'. + for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { + auto new_function = WithFields( + GetRef(function_node), function_node->params, function_node->body, + function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); + new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); + output_mod->Update(kv.first, new_function); + } + } + } + + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") + .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); +TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") + .set_body_typed(MarkCompilerFunctionsAsExtern); + +} // namespace transforms +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h new file mode 100644 index 0000000000000..7b5143444bf8a --- /dev/null +++ b/src/relay/transforms/compiler_function_utils.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/compiler_function_utils.h + * \brief Helper passes for working with functions with the "Compiler" attribute. + * + * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external + * codegen may find the following two helper passes useful: + * + * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a + * matching "Compiler" attribute to be global functions, using the "global_symbol" attribute + * already assigned. Can be used before custom lowering. + * + * Note that ideally "Compiler" attributed functions would be made global functions as early as + * possible and would stay that way. However, the GraphExecutorCodegen and AOTExecutorCodegen + * assume the entire model can be represented by a single 'main' function, and the Inline pass + * is run to respect that assumption. 
So this pass is mostly just to undo that Pass after modules + * have passed through the 'codegen' keyhole. + * + * See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc. + * + * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom + * cache to both allocate "global_symbol" names and ensure two strucurally equal functions are + * assigned the same name, and thus lowered only once. This is used by Collage when preparing + * the optimally partitioned IRModule). + * + * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" + * attribute with the same function with just an "Extern" attribute, signalling the function + * has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered' + * calling convention. Can be used after lowering to cleanup the IRModule. + * + * Note that the above behaviour is hard coded within the TECompiler, but is only available to + * external codegen using the Function-at-a-time "relay.ext.toolchain" extension point. + */ + +#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ +#define TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ + +#include +#include +#include + +#include "tvm/ir/transform.h" +#include "tvm/relay/function.h" + +namespace tvm { +namespace relay { +namespace transforms { + +/*! + * \brief Abstract class representing a cache of unique global vars keyed by functions. This can + * be used to ensure structurally equal functions are assigned the same global var object, and + * thus lowered at most once. + */ +class GlobalSymbolCache { + public: + virtual GlobalVar GetGlobalSymbol(const Function& function) = 0; +}; + +/*! + * \brief A \p GlobalSymbolCache that requires every "Compiler" attributed function to already + * have a "global_symbol" attribute. + */ +class ExistingGlobalSymbolCache : public GlobalSymbolCache { + public: + ExistingGlobalSymbolCache() = default; + + GlobalVar GetGlobalSymbol(const Function& function) final; + + private: + /*! \brief Maps already seen global symbol names to their corresponding GlobalVar objects. */ + std::unordered_map global_vars_; +}; + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The given \p GlobalSymbolCache is used to determine a unique global symbol for each + * function, which is also assigned to the "global_symbol" attribute of the new global function. + * + * At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + */ +transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter = ""); + +/*! + * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" + * attribute. The functions are bound to unique global vars according to their existing + * "global_symbol" attribute. At most one function with the same global symbol is outlined. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism + * to prepare the IRModule before custom lowering. + */ +transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = ""); + +/*! 
+ * \brief A pass to mark all global functions which have a "Compiler" attribute matching + * compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and + * rewrite all calls to such functions to use the 'call_lowered' calling convention. + * + * If \p compiler_filter is non-empty only functions with that as their attribute value are + * outlined. + * + * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to + * cleanup the IRModule after custom lowering. + */ +transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); + +} // namespace transforms +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 45cb8271b0746..18d2de1bdede3 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -84,7 +84,7 @@ class PurityVisitor : ExprFunctor { for (const auto& kv : mod_->functions) { if (const auto* function_node = kv.second.as()) { if (function_node->HasNonzeroAttr(attr::kPrimitive) || - function_node->GetAttr(attr::kExternalSymbol)) { + function_node->HasNonzeroAttr(attr::kExtern)) { // Ignore primitive and external functions. continue; } @@ -133,9 +133,11 @@ class PurityVisitor : ExprFunctor { Purity VisitExpr_(const GlobalVarNode* global_var_node) final { auto global_var = GetRef(global_var_node); + ICHECK(mod_->ContainGlobalVar(global_var_node->name_hint)) + << "No definition for '" << global_var_node->name_hint << "'"; auto func = mod_->Lookup(global_var); if (const auto* function_node = func.as()) { - if (!function_node->GetAttr(attr::kExternalSymbol)) { + if (!function_node->HasNonzeroAttr(attr::kExtern)) { return VisitGlobalFunction(global_var, GetRef(function_node)); } } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index c55b6778093e5..012b3579494f1 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -110,7 +110,7 @@ class Inliner : ExprMutator { if (!function_node->body.defined()) return false; // The function must be annotated with the inline attribute. - // (Note that external functions do not have this attribute!) + // (Note that partitioned functions and external functions do not have this attribute!) if (!function_node->HasNonzeroAttr(attr::kInline)) return false; // The function is not able to be inlined if any callee under the CallGraph @@ -136,8 +136,7 @@ class Inliner : ExprMutator { auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined() && - !func->GetAttr(attr::kExternalSymbol).defined()) { + if (!func->GetAttr(attr::kCompiler).defined() && !func->HasNonzeroAttr(attr::kExtern)) { ICHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py new file mode 100644 index 0000000000000..13e0f98e79f19 --- /dev/null +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License +"""Unit tests for the OutlineCompilerFunctionsWithExistingGlobalSymbols and + MarkCompilerFunctionsAsExtern external codegen helper passes.""" + +import tvm +import tvm.testing +import numpy as np + + +def make_const(dtype, shape): + return tvm.relay.const(np.random.rand(*shape).astype(dtype)) + + +def make_consts(dtype, shapes): + return [make_const(dtype, shape) for shape in shapes] + + +metatable = { + "relay.Constant": make_consts( + "float16", + [ + (2304, 768), # 0 + (2304,), # 1 + (600, 32, 64), # 2 + ], + ), + "attributes": [{"relay_attrs": None}], +} + + +def inlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %0 = fn(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + }; + %1 = %0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + +def expected_outlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + 
PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def expected_extern_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, meta[relay.Constant][0], meta[relay.Constant][1]), metadata=meta[attributes][0]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + + def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Extern=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + } + """, + "from_string", + None, + metatable, + ) + + +def test_outline_compiler_functions_with_existing_global_symbols(): + actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( + "cutlass" + )(inlined_mod()) + tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) + + +def test_mark_compiler_functions_as_extern(): + actual_extern_mod = tvm.relay.transform.MarkCompilerFunctionsAsExtern("cutlass")( + expected_outlined_mod() + ) + tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main() From 4811d702f3cadf5b06d7c1947846b10b90b19e79 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 3 Jun 2022 15:23:32 -0500 Subject: [PATCH 035/181] [Hexagon] Register strategy for concatenate (#11562) * [Hexagon] Register strategy for concatenate * Restart CI --- 
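As an illustration of what this enables (the shapes, variable names, and single-target build below are assumptions, and actually running it requires a Hexagon-enabled build of TVM), a concatenate expression can now be lowered for the Hexagon target:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 32), dtype="float32")
    y = relay.var("y", shape=(1, 32), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.concatenate((x, y), axis=1))

    # With the strategy registered, relay.build can lower concatenate for Hexagon.
    target = tvm.target.hexagon("v68")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=tvm.target.Target(target, host=target))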
python/tvm/relay/op/strategy/hexagon.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index da15a5412517d..be01ee50fba82 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -26,7 +26,7 @@ @batch_matmul_strategy.register("hexagon") -def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): +def batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): """batch_matmul strategy for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( @@ -37,6 +37,18 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): return strategy +@concatenate_strategy.register("hexagon") +def concatenate_strategy_hexagon(attrs, inputs, out_type, target): + """concatenate strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.concatenate), + wrap_topi_schedule(topi.hexagon.schedule_injective), + name="concatenate.hexagon", + ) + return strategy + + @conv2d_strategy.register("hexagon") def conv2d_strategy_hexagon(attrs, inputs, out_type, target): """Conv2d strategy for Hexagon""" From cee74c9f8f5563b1bed1956acccd6027d530d45e Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Sat, 4 Jun 2022 02:02:09 +0530 Subject: [PATCH 036/181] [CI] Update to LLVM 14.0.0 for ci_hexagon (#11539) --- docker/install/ubuntu_install_hexagon.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_hexagon.sh b/docker/install/ubuntu_install_hexagon.sh index 46d2a44cfaa52..e616c8a4977cc 100755 --- a/docker/install/ubuntu_install_hexagon.sh +++ b/docker/install/ubuntu_install_hexagon.sh @@ -21,9 +21,9 @@ set -o pipefail # Install LLVM/clang CLANG_LLVM_HOME=/opt/clang-llvm -CLANG_LLVM_VERSION=13.0.0 +CLANG_LLVM_VERSION=14.0.0 CLANG_LLVM_FILENAME=clang_llvm.tar.xz -wget -q https://github.com/llvm/llvm-project/releases/download/llvmorg-${CLANG_LLVM_VERSION}/clang+llvm-${CLANG_LLVM_VERSION}-x86_64-linux-gnu-ubuntu-16.04.tar.xz -O ${CLANG_LLVM_FILENAME} +wget -q https://github.com/llvm/llvm-project/releases/download/llvmorg-${CLANG_LLVM_VERSION}/clang+llvm-${CLANG_LLVM_VERSION}-x86_64-linux-gnu-ubuntu-18.04.tar.xz -O ${CLANG_LLVM_FILENAME} mkdir ${CLANG_LLVM_HOME} tar -xvf ${CLANG_LLVM_FILENAME} -C ${CLANG_LLVM_HOME} --strip-components=1 rm ${CLANG_LLVM_FILENAME} From b885362c36eff6d08363d53e5816f696a99ac822 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Jun 2022 16:03:08 -0500 Subject: [PATCH 037/181] [CI] Refactor of tvm.testing.requires_* annotations (#11313) * [CI] Improved skip messages when using @tvm.testing.requires_* Previously, the same message was given regardless of why a test couldn't be run. This has been split up into separate checks for TVM cmake options in `config.cmake`, enabled targets in `TVM_TEST_TARGETS` environment variable, and checks for available hardware. * Refactor to specify repeated feature marks, compile-only markers * Fixed lint errors * Import from contrib, not from a different import * Removed use of requires_llvm() as a list of marks * Corrected mark from requires_gpu to requires_cuda * Adding missing "not" * Added USE_CMSISNN as a requirement for corstone300. 
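As an illustration of the resulting API (the test names below are hypothetical), each requires_* decorator still works as a bare mark, now also accepts a support level, and exposes its underlying marks as a list:

    @tvm.testing.requires_cuda
    def test_compile_and_run():
        ...

    @tvm.testing.requires_cuda(support_required="compile-only")
    def test_compile_only():
        ...

    # For parametrization, the marks can be retrieved explicitly:
    pytest.param(True, marks=tvm.testing.requires_cuda.marks())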
--- python/tvm/testing/plugin.py | 25 +- python/tvm/testing/utils.py | 799 ++++++++++++---------- tests/python/contrib/test_dnnl.py | 4 +- tests/python/contrib/test_tensorrt.py | 4 +- tests/python/driver/tvmc/test_compiler.py | 12 +- tests/python/integration/test_reduce.py | 2 +- 6 files changed, 463 insertions(+), 383 deletions(-) diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index e90bd5e6dbf52..1f4f983b72102 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -56,8 +56,8 @@ def pytest_configure(config): """Runs at pytest configure time, defines marks to be used later.""" - for markername, desc in MARKERS.items(): - config.addinivalue_line("markers", "{}: {}".format(markername, desc)) + for feature in utils.Feature._all_features.values(): + feature._register_marker(config) print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets()))) print("pytest marker:", config.option.markexpr) @@ -269,25 +269,26 @@ def _target_to_requirement(target): # mapping from target to decorator if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): - return utils.requires_cudnn() + return utils.requires_cudnn.marks() if target.kind.name == "cuda" and "cublas" in target.attrs.get("libs", []): - return utils.requires_cublas() + return utils.requires_cublas.marks() if target.kind.name == "cuda": - return utils.requires_cuda() + return utils.requires_cuda.marks() if target.kind.name == "rocm": - return utils.requires_rocm() + return utils.requires_rocm.marks() if target.kind.name == "vulkan": - return utils.requires_vulkan() + return utils.requires_vulkan.marks() if target.kind.name == "nvptx": - return utils.requires_nvptx() + return utils.requires_nvptx.marks() if target.kind.name == "metal": - return utils.requires_metal() + return utils.requires_metal.marks() if target.kind.name == "opencl": - return utils.requires_opencl() + return utils.requires_opencl.marks() if target.kind.name == "llvm": - return utils.requires_llvm() + return utils.requires_llvm.marks() if target.kind.name == "hexagon": - return utils.requires_hexagon() + return utils.requires_hexagon.marks() + return [] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 0e2d7be4a14e7..939786c9294fc 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -67,15 +67,20 @@ def test_something(): import copyreg import ctypes import functools +import itertools import logging import os +import pickle import platform import shutil import sys import time -import pickle + +from typing import Optional, Callable, Union, List + import pytest import numpy as np + import tvm import tvm.arith import tvm.tir @@ -84,9 +89,6 @@ def test_something(): from tvm.contrib import nvcc, cudnn from tvm.error import TVMError -from tvm.relay.op.contrib.ethosn import ethosn_available -from tvm.relay.op.contrib import cmsisnn -from tvm.relay.op.contrib import vitis_ai SKIP_SLOW_TESTS = os.getenv("SKIP_SLOW_TESTS", "").lower() in {"true", "1", "yes"} @@ -388,12 +390,9 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): ) -def _get_targets(target_str=None): - if target_str is None: - target_str = os.environ.get("TVM_TEST_TARGETS", "") - # Use dict instead of set for de-duplication so that the - # targets stay in the order specified. 
- target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()}) +def _get_targets(target_names=None): + if target_names is None: + target_names = _tvm_test_targets() if not target_names: target_names = DEFAULT_TEST_TARGETS @@ -429,7 +428,7 @@ def _get_targets(target_str=None): " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.", target_str, ) - return _get_targets("llvm") + return _get_targets(["llvm"]) raise TVMError( "None of the following targets are supported by this build of TVM: %s." @@ -515,458 +514,544 @@ def enabled_targets(): return [(t["target"], tvm.device(t["target"])) for t in _get_targets() if t["is_runnable"]] -def _compose(args, decs): - """Helper to apply multiple markers""" - if len(args) > 0: - f = args[0] - for d in reversed(decs): - f = d(f) - return f - return decs +class Feature: + """A feature that may be required to run a test. -def slow(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if SKIP_SLOW_TESTS: - pytest.skip("Skipping slow test since RUN_SLOW_TESTS environment variables is 'true'") - else: - fn(*args, **kwargs) + Parameters + ---------- + name: str - return wrapper + The short name of the feature. Should match the name in the + requires_* decorator. This is applied as a mark to all tests + using this feature, and can be used in pytests ``-m`` + argument. + long_name: Optional[str] -def uses_gpu(*args): - """Mark to differentiate tests that use the GPU in some capacity. + The long name of the feature, to be used in error messages. - These tests will be run on CPU-only test nodes and on test nodes with GPUs. - To mark a test that must have a GPU present to run, use - :py:func:`tvm.testing.requires_gpu`. + If None, defaults to the short name. - Parameters - ---------- - f : function - Function to mark - """ - _uses_gpu = [pytest.mark.gpu] - return _compose(args, _uses_gpu) + cmake_flag: Optional[str] + The flag that must be enabled in the config.cmake in order to + use this feature. -def requires_x86(*args): - """Mark a test as requiring the x86 Architecture to run. + If None, no flag is required to use this feature. - Tests with this mark will not be run unless on an x86 platform. + target_kind_enabled: Optional[str] - Parameters - ---------- - f : function - Function to mark - """ - _requires_x86 = [ - pytest.mark.skipif(platform.machine() != "x86_64", reason="x86 Architecture Required"), - ] - return _compose(args, _requires_x86) + The target kind that must be enabled to run tests using this + feature. If present, the target_kind must appear in the + TVM_TEST_TARGETS environment variable, or in + tvm.testing.DEFAULT_TEST_TARGETS if TVM_TEST_TARGETS is + undefined. + If None, this feature does not require a specific target to be + enabled. -def requires_gpu(*args): - """Mark a test as requiring a GPU to run. + compile_time_check: Optional[Callable[[], Union[bool,str]]] - Tests with this mark will not be run unless a gpu is present. + A check that returns True if the feature can be used at + compile-time. (e.g. Validating the version number of the nvcc + compiler.) If the feature does not have support to perform + compile-time tests, the check should returns False to display + a generic error message, or a string to display a more + specific error message. 
- Parameters - ---------- - f : function - Function to mark - """ - _requires_gpu = [ - pytest.mark.skipif( - not tvm.cuda().exist - and not tvm.rocm().exist - and not tvm.opencl().exist - and not tvm.metal().exist - and not tvm.vulkan().exist, - reason="No GPU present", - ), - *uses_gpu(), - ] - return _compose(args, _requires_gpu) + If None, no additional check is performed. + target_kind_hardware: Optional[str] -def requires_cuda(*args): - """Mark a test as requiring the CUDA runtime. + The target kind that must have available hardware in order to + run tests using this feature. This is checked using + tvm.device(target_kind_hardware).exist. If a feature requires + a different check, this should be implemented using + run_time_check. - This also marks the test as requiring a cuda gpu. + If None, this feature does not require a specific + tvm.device to exist. - Parameters - ---------- - f : function - Function to mark - """ - _requires_cuda = [ - pytest.mark.cuda, - pytest.mark.skipif(not device_enabled("cuda"), reason="CUDA support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_cuda) + run_time_check: Optional[Callable[[], Union[bool,str]]] + A check that returns True if the feature can be used at + run-time. (e.g. Validating the compute version supported by a + GPU.) If the feature does not have support to perform + run-time tests, the check should returns False to display a + generic error message, or a string to display a more specific + error message. -def requires_cudnn(*args): - """Mark a test as requiring the cuDNN library. + If None, no additional check is performed. - This also marks the test as requiring a cuda gpu. + parent_features: Optional[Union[str,List[str]]] - Parameters - ---------- - f : function - Function to mark - """ + The short name of a feature or features that are required in + order to use this feature. (e.g. Using cuDNN requires using + CUDA) This feature should inherit all checks of the parent + feature, with the exception of the `target_kind_enabled` + checks. - requirements = [ - pytest.mark.skipif( - not cudnn.exists(), reason="cuDNN library not enabled, or not installed" - ), - *requires_cuda(), - ] - return _compose(args, requirements) + If None, this feature does not require any other parent + features. + """ -def requires_cublas(*args): - """Mark a test as requiring the cuBLAS library. + _all_features = {} + + def __init__( + self, + name: str, + long_name: Optional[str] = None, + cmake_flag: Optional[str] = None, + target_kind_enabled: Optional[str] = None, + compile_time_check: Optional[Callable[[], Union[bool, str]]] = None, + target_kind_hardware: Optional[str] = None, + run_time_check: Optional[Callable[[], Union[bool, str]]] = None, + parent_features: Optional[Union[str, List[str]]] = None, + ): + self.name = name + self.long_name = long_name or name + self.cmake_flag = cmake_flag + self.target_kind_enabled = target_kind_enabled + self.compile_time_check = compile_time_check + self.target_kind_hardware = target_kind_hardware + self.run_time_check = run_time_check + + if parent_features is None: + self.parent_features = [] + elif isinstance(parent_features, str): + self.parent_features = [parent_features] + else: + self.parent_features = parent_features - This also marks the test as requiring a cuda gpu. 
+ self._all_features[self.name] = self - Parameters - ---------- - f : function - Function to mark - """ + def _register_marker(self, config): + config.addinivalue_line("markers", f"{self.name}: Mark a test as using {self.long_name}") - requirements = [ - pytest.mark.skipif( - tvm.get_global_func("tvm.contrib.cublas.matmul", True), - reason="cuDNN library not enabled", - ), - *requires_cuda(), - ] - return _compose(args, requirements) + def _uses_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._uses_marks() + yield getattr(pytest.mark, self.name) -def requires_nvptx(*args): - """Mark a test as requiring the NVPTX compilation on the CUDA runtime + def _compile_only_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._compile_only_marks() - This also marks the test as requiring a cuda gpu, and requiring - LLVM support. + if self.compile_time_check is not None: + res = self.compile_time_check() + if isinstance(res, str): + yield pytest.mark.skipif(True, reason=res) + else: + yield pytest.mark.skipif( + not res, reason=f"Compile-time support for {self.long_name} not present" + ) - Parameters - ---------- - f : function - Function to mark + if self.target_kind_enabled is not None: + target_kind = self.target_kind_enabled.split()[0] + yield pytest.mark.skipif( + all(enabled.split()[0] != target_kind for enabled in _tvm_test_targets()), + reason=( + f"{self.target_kind_enabled} tests disabled " + f"by TVM_TEST_TARGETS environment variable" + ), + ) - """ - _requires_nvptx = [ - pytest.mark.skipif(not device_enabled("nvptx"), reason="NVPTX support not enabled"), - *requires_llvm(), - *requires_gpu(), - ] - return _compose(args, _requires_nvptx) + if self.cmake_flag is not None: + yield pytest.mark.skipif( + not _cmake_flag_enabled(self.cmake_flag), + reason=( + f"{self.long_name} support not enabled. " + f"Set {self.cmake_flag} in config.cmake to enable." + ), + ) + def _run_only_marks(self): + for parent in self.parent_features: + yield from self._all_features[parent]._run_only_marks() + + if self.run_time_check is not None: + res = self.run_time_check() + if isinstance(res, str): + yield pytest.mark.skipif(True, reason=res) + else: + yield pytest.mark.skipif( + not res, reason=f"Run-time support for {self.long_name} not present" + ) -def requires_nvcc_version(major_version, minor_version=0, release_version=0): - """Mark a test as requiring at least a specific version of nvcc. + if self.target_kind_hardware is not None: + yield pytest.mark.skipif( + not tvm.device(self.target_kind_hardware).exist, + reason=f"No device exists for target {self.target_kind_hardware}", + ) - Unit test marked with this decorator will run only if the - installed version of NVCC is at least `(major_version, - minor_version, release_version)`. + def marks(self, support_required="compile-and-run"): + """Return a list of marks to be used - This also marks the test as requiring a cuda support. + Parameters + ---------- - Parameters - ---------- - major_version: int + support_required: str - The major version of the (major,minor,release) version tuple. + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". - minor_version: int + See Feature.__call__ for details. + """ + if support_required not in ["compile-and-run", "compile-only", "optional"]: + raise ValueError(f"Unknown feature support type: {support_required}") - The minor version of the (major,minor,release) version tuple. 
+ if support_required == "compile-and-run": + marks = itertools.chain( + self._run_only_marks(), self._compile_only_marks(), self._uses_marks() + ) + elif support_required == "compile-only": + marks = itertools.chain(self._compile_only_marks(), self._uses_marks()) + elif support_required == "optional": + marks = self._uses_marks() + else: + raise ValueError(f"Unknown feature support type: {support_required}") - release_version: int + return list(marks) - The release version of the (major,minor,release) version tuple. + def __call__(self, func=None, *, support_required="compile-and-run"): + """Mark a pytest function as requiring this feature - """ + Can be used either as a bare decorator, or as a decorator with + arguments. - try: - nvcc_version = nvcc.get_cuda_version() - except RuntimeError: - nvcc_version = (0, 0, 0) + Parameters + ---------- - min_version = (major_version, minor_version, release_version) - version_str = ".".join(str(v) for v in min_version) - requires = [ - pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), - *requires_cuda(), - ] + func: Callable - def inner(func): - return _compose([func], requires) + The pytest test function to be marked - return inner + support_required: str + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". -def skip_if_32bit(reason): - def decorator(*args): - if "32bit" in platform.architecture()[0]: - return _compose(args, [pytest.mark.skip(reason=reason)]) + If "compile-and-run", the test case is marked as using the + feature, and is skipped if the environment lacks either + compile-time or run-time support for the feature. - return _compose(args, []) + If "compile-only", the test case is marked as using the + feature, and is skipped if the environment lacks + compile-time support. - return decorator + If "optional", the test case is marked as using the + feature, but isn't skipped. This is kept for backwards + compatibility for tests that use `enabled_targets()`, and + should be avoided in new test code. Instead, prefer + parametrizing over the target using the `target` fixture. + Examples + -------- -def requires_cudagraph(*args): - """Mark a test as requiring the CUDA Graph Feature + .. code-block:: python - This also marks the test as requiring cuda + @feature + def test_compile_and_run(): + ... - Parameters - ---------- - f : function - Function to mark - """ - _requires_cudagraph = [ - pytest.mark.skipif( - not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment" - ), - *requires_cuda(), - ] - return _compose(args, _requires_cudagraph) + @feature(compile_only=True) + def test_compile_only(): + ... + """ -def requires_opencl(*args): - """Mark a test as requiring the OpenCL runtime. + if support_required not in ["compile-and-run", "compile-only", "optional"]: + raise ValueError(f"Unknown feature support type: {support_required}") - This also marks the test as requiring a gpu. 
+ def wrapper(func): + for mark in self.marks(support_required=support_required): + func = mark(func) + return func - Parameters - ---------- - f : function - Function to mark - """ - _requires_opencl = [ - pytest.mark.opencl, - pytest.mark.skipif(not device_enabled("opencl"), reason="OpenCL support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_opencl) + if func is None: + return wrapper + return wrapper(func) -def requires_corstone300(*args): - """Mark a test as requiring the corstone300 FVP + @classmethod + def require(cls, name, support_required="compile-and-run"): + """Returns a decorator that marks a test as requiring a feature - Parameters - ---------- - f : function - Function to mark - """ - _requires_corstone300 = [ - pytest.mark.corstone300, - pytest.mark.skipif( - shutil.which("arm-none-eabi-gcc") is None, reason="ARM embedded toolchain unavailable" - ), - ] - return _compose(args, _requires_corstone300) + Parameters + ---------- + name: str -def requires_rocm(*args): - """Mark a test as requiring the rocm runtime. + The name of the feature that is used by the test - This also marks the test as requiring a gpu. + support_required: str - Parameters - ---------- - f : function - Function to mark - """ - _requires_rocm = [ - pytest.mark.rocm, - pytest.mark.skipif(not device_enabled("rocm"), reason="rocm support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_rocm) + Allowed values: "compile-and-run" (default), + "compile-only", or "optional". + See Feature.__call__ for details. -def requires_metal(*args): - """Mark a test as requiring the metal runtime. + Examples + -------- - This also marks the test as requiring a gpu. + .. code-block:: python - Parameters - ---------- - f : function - Function to mark - """ - _requires_metal = [ - pytest.mark.metal, - pytest.mark.skipif(not device_enabled("metal"), reason="metal support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_metal) + @Feature.require("cuda") + def test_compile_and_run(): + ... + @Feature.require("cuda", compile_only=True) + def test_compile_only(): + ... + """ + return cls._all_features[name](support_required=support_required) -def requires_vulkan(*args): - """Mark a test as requiring the vulkan runtime. - This also marks the test as requiring a gpu. +def _any_gpu_exists(): + return ( + tvm.cuda().exist + or tvm.rocm().exist + or tvm.opencl().exist + or tvm.metal().exist + or tvm.vulkan().exist + ) - Parameters - ---------- - f : function - Function to mark - """ - _requires_vulkan = [ - pytest.mark.vulkan, - pytest.mark.skipif(not device_enabled("vulkan"), reason="vulkan support not enabled"), - *requires_gpu(), - ] - return _compose(args, _requires_vulkan) +# Mark a test as requiring llvm to run +requires_llvm = Feature( + "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" +) -def requires_tensorcore(*args): - """Mark a test as requiring a tensorcore to run. +# Mark a test as requiring a GPU to run. +requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists) - Tests with this mark will not be run unless a tensorcore is present. +# Mark to differentiate tests that use the GPU in some capacity. +# +# These tests will be run on CPU-only test nodes and on test nodes with GPUs. +# To mark a test that must have a GPU present to run, use +# :py:func:`tvm.testing.requires_gpu`. +uses_gpu = requires_gpu(support_required="optional") + +# Mark a test as requiring the x86 Architecture to run. 
+requires_x86 = Feature( + "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64" +) + +# Mark a test as requiring the CUDA runtime. +requires_cuda = Feature( + "cuda", + "CUDA", + cmake_flag="USE_CUDA", + target_kind_enabled="cuda", + target_kind_hardware="cuda", + parent_features="gpu", +) + +# Mark a test as requiring a tensorcore to run +requires_tensorcore = Feature( + "tensorcore", + "NVIDIA Tensor Core", + run_time_check=lambda: tvm.cuda().exist and nvcc.have_tensorcore(tvm.cuda().compute_version), + parent_features="cuda", +) + +# Mark a test as requiring the cuDNN library. +requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_features="cuda") + +# Mark a test as requiring the cuBLAS library. +requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda") + +# Mark a test as requiring the NVPTX compilation on the CUDA runtime +requires_nvptx = Feature( + "nvptx", + "NVPTX", + target_kind_enabled="nvptx", + target_kind_hardware="nvptx", + parent_features=["llvm", "cuda"], +) + +# Mark a test as requiring the CUDA Graph Feature +requires_cudagraph = Feature( + "cudagraph", + "CUDA Graph", + target_kind_enabled="cuda", + compile_time_check=nvcc.have_cudagraph, + parent_features="cuda", +) + +# Mark a test as requiring the OpenCL runtime +requires_opencl = Feature( + "opencl", + "OpenCL", + cmake_flag="USE_OPENCL", + target_kind_enabled="opencl", + target_kind_hardware="opencl", + parent_features="gpu", +) + +# Mark a test as requiring the rocm runtime +requires_rocm = Feature( + "rocm", + "ROCm", + cmake_flag="USE_ROCM", + target_kind_enabled="rocm", + target_kind_hardware="rocm", + parent_features="gpu", +) + +# Mark a test as requiring the metal runtime +requires_metal = Feature( + "metal", + "Metal", + cmake_flag="USE_METAL", + target_kind_enabled="metal", + target_kind_hardware="metal", + parent_features="gpu", +) + +# Mark a test as requiring the vulkan runtime +requires_vulkan = Feature( + "vulkan", + "Vulkan", + cmake_flag="USE_VULKAN", + target_kind_enabled="vulkan", + target_kind_hardware="vulkan", + parent_features="gpu", +) + +# Mark a test as requiring microTVM to run +requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO") + +# Mark a test as requiring rpc to run +requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC") + +# Mark a test as requiring Arm(R) Ethos(TM)-N to run +requires_ethosn = Feature("ethosn", "Arm(R) Ethos(TM)-N", cmake_flag="USE_ETHOSN") + +# Mark a test as requiring Hexagon to run +requires_hexagon = Feature( + "hexagon", + "Hexagon", + cmake_flag="USE_HEXAGON", + target_kind_enabled="hexagon", + compile_time_check=lambda: ( + (_cmake_flag_enabled("USE_LLVM") and tvm.target.codegen.llvm_version_major() >= 7) + or "Hexagon requires LLVM 7 or later" + ), + target_kind_hardware="hexagon", + parent_features="llvm", +) + +# Mark a test as requiring the CMSIS NN library +requires_cmsisnn = Feature("cmsisnn", "CMSIS NN", cmake_flag="USE_CMSISNN") + +# Mark a test as requiring the corstone300 FVP +requires_corstone300 = Feature( + "corstone300", + "Corstone-300", + compile_time_check=lambda: ( + (shutil.which("arm-none-eabi-gcc") is None) or "ARM embedded toolchain unavailable" + ), + parent_features="cmsisnn", +) + +# Mark a test as requiring Vitis AI to run +requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI") + + +def _cmake_flag_enabled(flag): + flag = tvm.support.libinfo()[flag] + + # Because many of the flags can be library flags, we 
check if the + # flag is not disabled, rather than checking if it is enabled. + return flag.lower() not in ["off", "false", "0"] + + +def _tvm_test_targets(): + target_str = os.environ.get("TVM_TEST_TARGETS", "").strip() + if target_str: + # Use dict instead of set for de-duplication so that the + # targets stay in the order specified. + return list({t.strip(): None for t in target_str.split(";") if t.strip()}) - Parameters - ---------- - f : function - Function to mark - """ - _requires_tensorcore = [ - pytest.mark.tensorcore, - pytest.mark.skipif( - not tvm.cuda().exist or not nvcc.have_tensorcore(tvm.cuda(0).compute_version), - reason="No tensorcore present", - ), - *requires_gpu(), - ] - return _compose(args, _requires_tensorcore) + return DEFAULT_TEST_TARGETS -def requires_llvm(*args): - """Mark a test as requiring llvm to run. +def _compose(args, decs): + """Helper to apply multiple markers""" + if len(args) > 0: + f = args[0] + for d in reversed(decs): + f = d(f) + return f + return decs - Parameters - ---------- - f : function - Function to mark - """ - _requires_llvm = [ - pytest.mark.llvm, - pytest.mark.skipif(not device_enabled("llvm"), reason="LLVM support not enabled"), - ] - return _compose(args, _requires_llvm) +def slow(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if SKIP_SLOW_TESTS: + pytest.skip("Skipping slow test since RUN_SLOW_TESTS environment variables is 'true'") + else: + fn(*args, **kwargs) -def requires_micro(*args): - """Mark a test as requiring microTVM to run. + return wrapper - Parameters - ---------- - f : function - Function to mark - """ - _requires_micro = [ - pytest.mark.skipif( - tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON", - reason="MicroTVM support not enabled. Set USE_MICRO=ON in config.cmake to enable.", - ) - ] - return _compose(args, _requires_micro) +def requires_nvcc_version(major_version, minor_version=0, release_version=0): + """Mark a test as requiring at least a specific version of nvcc. -def requires_rpc(*args): - """Mark a test as requiring rpc to run. + Unit test marked with this decorator will run only if the + installed version of NVCC is at least `(major_version, + minor_version, release_version)`. + + This also marks the test as requiring a cuda support. Parameters ---------- - f : function - Function to mark - """ - _requires_rpc = [ - pytest.mark.skipif( - tvm.support.libinfo().get("USE_RPC", "OFF") != "ON", - reason="RPC support not enabled. Set USE_RPC=ON in config.cmake to enable.", - ) - ] - return _compose(args, _requires_rpc) + major_version: int + The major version of the (major,minor,release) version tuple. -def requires_ethosn(*args): - """Mark a test as requiring Arm(R) Ethos(TM)-N to run. + minor_version: int - Parameters - ---------- - f : function - Function to mark - """ - marks = [ - pytest.mark.ethosn, - pytest.mark.skipif( - not ethosn_available(), - reason=( - "Arm(R) Ethos(TM)-N support not enabled. " - "Set USE_ETHOSN=ON in config.cmake to enable, " - "and ensure that hardware support is present." - ), - ), - ] - return _compose(args, marks) + The minor version of the (major,minor,release) version tuple. + release_version: int -def requires_hexagon(*args): - """Mark a test as requiring Hexagon to run. + The release version of the (major,minor,release) version tuple. 
- Parameters - ---------- - f : function - Function to mark """ - _requires_hexagon = [ - pytest.mark.hexagon, - pytest.mark.skipif(not device_enabled("hexagon"), reason="Hexagon support not enabled"), - *requires_llvm(), - pytest.mark.skipif( - tvm.target.codegen.llvm_version_major() < 7, reason="Hexagon requires LLVM 7 or later" - ), - ] - return _compose(args, _requires_hexagon) + try: + nvcc_version = nvcc.get_cuda_version() + except RuntimeError: + nvcc_version = (0, 0, 0) -def requires_cmsisnn(*args): - """Mark a test as requiring the CMSIS NN library. + min_version = (major_version, minor_version, release_version) + version_str = ".".join(str(v) for v in min_version) + requires = [ + pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), + *requires_cuda.marks(), + ] - Parameters - ---------- - f : function - Function to mark - """ + def inner(func): + return _compose([func], requires) - requirements = [pytest.mark.skipif(not cmsisnn.enabled(), reason="CMSIS NN not enabled")] - return _compose(args, requirements) + return inner -def requires_vitis_ai(*args): - """Mark a test as requiring Vitis AI to run. +def skip_if_32bit(reason): + def decorator(*args): + if "32bit" in platform.architecture()[0]: + return _compose(args, [pytest.mark.skip(reason=reason)]) - Parameters - ---------- - f : function - Function to mark - """ + return _compose(args, []) - requirements = [pytest.mark.skipif(not vitis_ai.enabled(), reason="Vitis AI not enabled")] - return _compose(args, requirements) + return decorator def requires_package(*packages): diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 76e3f1c3a4055..19ac183d66dfe 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -34,8 +34,8 @@ ) run_module = tvm.testing.parameter( - pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]), - pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]), + pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), + pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), ids=["compile", "run"], ) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 982ec976d54ed..cecb64785a49a 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -44,9 +44,9 @@ ) run_module = tvm.testing.parameter( - pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda()]), + pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()]), pytest.param( - True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda()] + True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()] ), ids=["compile", "run"], ) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index d6ae27957de2d..e8e93a6c7514d 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -25,7 +25,7 @@ import tvm import tvm.testing -from tvm.testing.utils import ethosn_available +from tvm.relay.op.contrib.ethosn import ethosn_available from tvm.relay.backend import Runtime, Executor from tvm.contrib.target.vitis_ai import vitis_ai_available @@ -412,10 +412,7 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( assert len(c_source_files) == 4 -@pytest.mark.skipif( - not ethosn_available(), - 
reason="--target=Ethos(TM)-N78 is not available. TVM built with 'USE_ETHOSN OFF'", -) +@tvm.testing.requires_ethosn def test_compile_tflite_module_with_external_codegen_ethos_n78(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) @@ -430,10 +427,7 @@ def test_compile_tflite_module_with_external_codegen_ethos_n78(tflite_mobilenet_ assert os.path.exists(dumps_path) -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) +@tvm.testing.requires_vitis_ai def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index a40164ded941e..f3886374ccb65 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -528,7 +528,7 @@ def check_target(device): check_target("rocm") -@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_reduce_storage_reuse(): target = tvm.target.Target("cuda") From 8823757f3037cdf2afe0ce6bb4f38fff8ef97536 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 3 Jun 2022 16:09:24 -0500 Subject: [PATCH 038/181] [TIR] Expose tir.call_cpacked in python (#11563) --- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 2d201bb0dab65..6db93b6ad0915 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -44,7 +44,7 @@ from .function import PrimFunc, TensorIntrin, IndexMap -from .op import call_packed, call_intrin, call_pure_extern, call_extern +from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index de3ca5fa8d5b2..5d15bf15da581 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -73,6 +73,33 @@ def call_packed(*args, span=None): return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span) +def call_cpacked(*args, span=None): + """Build expression by call an external packed function. + + Same as call_packed, except that the first argument is the function name + (as in call_extern), and the last argument is the resource handle. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + te.extern : Create tensor with extern function call. + """ + call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] + return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span) + + def call_intrin(dtype, func_name, *args, span=None): """Build expression by calling an intrinsic function. 
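For reference, a small sketch of how the newly exposed helper can be used; the callee name, buffers, and handle variable here are placeholders rather than anything from this patch:

    import tvm
    from tvm import tir

    a = tir.decl_buffer((16,), "float32", name="a")
    b = tir.decl_buffer((16,), "float32", name="b")
    resource_handle = tir.Var("rh", "handle")

    # First argument is the callee name, last argument is the resource handle;
    # Buffer arguments are packed (via _pack_buffer) just as in call_packed.
    call = tir.call_cpacked("my_packed_fn", a, b, resource_handle)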
From 6dbdf2e20116ecc6f5379f5cb430ed023ff0d62b Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 3 Jun 2022 14:22:05 -0700 Subject: [PATCH 039/181] Fix Hexagon build using ci.py (#11304) * Add output directory add post build for hexagon fix -net=host for docker * remove --net by default --- tests/scripts/ci.py | 6 +++++- tests/scripts/task_build_hexagon_api.sh | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index b3f9cb6500e53..599bbaddceec9 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -342,6 +342,7 @@ def generate_command( options: Dict[str, Option], help: str, precheck: Optional[Callable[[], None]] = None, + post_build: Optional[List[str]] = None, ): """ Helper to generate CLIs that: @@ -378,6 +379,9 @@ def fn( f"./tests/scripts/task_build.py --build-dir {get_build_dir(name)}", ] + if post_build is not None: + scripts += post_build + # Check that a test suite was not used alongside specific test names if any(v for v in kwargs.values()) and tests is not None: option_flags = ", ".join([f"--{k}" for k in options.keys()]) @@ -624,12 +628,12 @@ def add_subparser( generate_command( name="hexagon", help="Run Hexagon build and test(s)", + post_build=["./tests/scripts/task_build_hexagon_api.sh --output build-hexagon"], options={ "cpp": CPP_UNITTEST, "test": ( "run Hexagon API/Python tests", [ - "./tests/scripts/task_build_hexagon_api.sh", "./tests/scripts/task_python_hexagon.sh", ], ), diff --git a/tests/scripts/task_build_hexagon_api.sh b/tests/scripts/task_build_hexagon_api.sh index 4c7b4f396ced4..5f811e4e27492 100755 --- a/tests/scripts/task_build_hexagon_api.sh +++ b/tests/scripts/task_build_hexagon_api.sh @@ -19,6 +19,15 @@ set -e set -u +output_directory_parent=$(realpath ${PWD}/build) +if [ $# -ge 1 ] && [[ "$1" == "--output" ]]; then + shift 1 + output_directory_parent=$(realpath $1) + shift 1 +fi +output_directory="${output_directory_parent}/hexagon_api_output" +rm -rf ${output_directory} + use_cache=false if [ $# -ge 1 ] && [[ "$1" == "--use-cache" ]]; then use_cache=true @@ -26,24 +35,19 @@ if [ $# -ge 1 ] && [[ "$1" == "--use-cache" ]]; then fi cd apps/hexagon_api - if [ "$use_cache" = false ]; then rm -rf build fi - mkdir -p build cd build -output_binary_directory=$(realpath ${PWD}/../../../build/hexagon_api_output) -rm -rf ${output_binary_directory} - cmake -DANDROID_ABI=arm64-v8a \ -DANDROID_PLATFORM=android-28 \ -DUSE_ANDROID_TOOLCHAIN="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \ -DUSE_HEXAGON_ARCH=v68 \ -DUSE_HEXAGON_SDK="${HEXAGON_SDK_ROOT}" \ -DUSE_HEXAGON_TOOLCHAIN="${HEXAGON_TOOLCHAIN}" \ - -DUSE_OUTPUT_BINARY_DIR="${output_binary_directory}" \ + -DUSE_OUTPUT_BINARY_DIR="${output_directory}" \ -DUSE_HEXAGON_GTEST="${HEXAGON_SDK_ROOT}/utils/googletest/gtest" .. 
make -j$(nproc) From f05ebde8e84e4bce620b0fdf839b89eb60c1008c Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Fri, 3 Jun 2022 16:23:46 -0600 Subject: [PATCH 040/181] [docs] microTVM model training tutorial with Colab support (#10921) * First draft of micro train tutorial * unit test code * Fix obvious formatting issues * Linting * Proof of concept showing that "Open in Colab" is possible * Make test Python script more readable * Fix formatting * Ready for review * Import pyserial only when needed Changes from code review Use official sphinx-gallery repo Correctly specify version Import pyserial only when necessary * Add warning to ignored list Try to avoid throwing warning Fix linting, try verbosity filter Try adding to ignore file Remove fix attempts * Grammar fixes * Address code review comments Include full git hashes * Rerun tests * Rerun again --- .../template_project/microtvm_api_server.py | 4 +- apps/microtvm/pyproject.toml | 2 +- docker/install/ubuntu_install_sphinx.sh | 2 +- docs/conf.py | 3 +- .../how_to/work_with_microtvm/micro_train.py | 649 ++++++++++++++++++ tests/scripts/ci.py | 3 +- tests/scripts/task_python_docs.sh | 2 + 7 files changed, 660 insertions(+), 5 deletions(-) create mode 100644 gallery/how_to/work_with_microtvm/micro_train.py diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 95f941fe34737..131f92a208298 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -34,7 +34,6 @@ import re from packaging import version -import serial.tools.list_ports from tvm.micro.project_api import server @@ -485,6 +484,9 @@ def flash(self, options): subprocess.run(upload_cmd, check=True) def open_transport(self, options): + import serial + import serial.tools.list_ports + # Zephyr example doesn't throw an error in this case if self._serial is not None: return diff --git a/apps/microtvm/pyproject.toml b/apps/microtvm/pyproject.toml index 98c769be48f51..5976328592290 100644 --- a/apps/microtvm/pyproject.toml +++ b/apps/microtvm/pyproject.toml @@ -129,7 +129,7 @@ importer-tflite = ["tflite", "tensorflow", "tensorflow-estimator"] autodocsumm = "^0.1" black = "^19.10b0" sphinx = "^3.0" -sphinx-gallery = "^0.8" +sphinx-gallery = { git = "https://github.com/sphinx-gallery/sphinx-gallery.git", rev = "6142f179" } sphinx-rtd-theme = "^0.4" matplotlib = "^3.2" Image = "^1.5" diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 12ca25b22b85a..96023fa6e633a 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -29,5 +29,5 @@ pip3 install \ matplotlib \ sphinx==4.2.0 \ sphinx_autodoc_annotation \ - sphinx-gallery==0.4.0 \ + "git+https://github.com/sphinx-gallery/sphinx-gallery.git@6142f1791151849b5bec4bf3959f75697ba226cd" \ sphinx_rtd_theme diff --git a/docs/conf.py b/docs/conf.py index 49c5c4fa755d2..9d55e20c03e5c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -312,6 +312,7 @@ def git_describe_version(original_version): "bring_your_own_datatypes.py", ], "micro": [ + "micro_train.py", "micro_autotune.py", "micro_reference_vm.py", "micro_tflite.py", @@ -360,11 +361,11 @@ def force_gc(gallery_conf, fname): "gallery_dirs": gallery_dirs, "subsection_order": subsection_order, "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", ".py"), - "find_mayavi_figures": False, "download_all_examples": False, 
"min_reported_time": 60, "expected_failing_examples": [], "reset_modules": ("matplotlib", "seaborn", force_gc), + "promote_jupyter_magic": True, } autodoc_default_options = { diff --git a/gallery/how_to/work_with_microtvm/micro_train.py b/gallery/how_to/work_with_microtvm/micro_train.py new file mode 100644 index 0000000000000..378fe56d9da01 --- /dev/null +++ b/gallery/how_to/work_with_microtvm/micro_train.py @@ -0,0 +1,649 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _microtvm-train-arduino: + +Training Vision Models for microTVM on Arduino +============================================== +**Author**: `Gavin Uberti `_ + +This tutorial shows how MobileNetV1 models can be trained +to fit on embedded devices, and how those models can be +deployed to Arduino using TVM. +""" + +###################################################################### +# .. note:: +# +# This tutorial is best viewed as a Jupyter Notebook. You can download and run it locally +# using the link at the bottom of this page, or open it online for free using Google Colab. +# Click the icon below to open in Google Colab. +# +# .. image:: https://raw.githubusercontent.com/guberti/web-data/micro-train-tutorial-data/images/utilities/colab_button.png +# :align: center +# :target: https://colab.research.google.com/github/guberti/tvm-site/blob/asf-site/docs/_downloads/a7c7ea4b5017ae70db1f51dd8e6dcd82/micro_train.ipynb +# :width: 300px +# +# Motivation +# ---------- +# When building IOT devices, we often want them to **see and understand** the world around them. +# This can take many forms, but often times a device will want to know if a certain **kind of +# object** is in its field of vision. +# +# For example, a security camera might look for **people**, so it can decide whether to save a video +# to memory. A traffic light might look for **cars**, so it can judge which lights should change +# first. Or a forest camera might look for a **kind of animal**, so they can estimate how large +# the animal population is. +# +# To make these devices affordable, we would like them to need only a low-cost processor like the +# `nRF52840 `_ (costing five dollars each on Mouser) or the `RP2040 `_ (just $1.45 each!). +# +# These devices have very little memory (~250 KB RAM), meaning that no conventional edge AI +# vision model (like MobileNet or EfficientNet) will be able to run. In this tutorial, we will +# show how these models can be modified to work around this requirement. Then, we will use TVM +# to compile and deploy it for an Arduino that uses one of these processors. +# +# Installing the Prerequisites +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# This tutorial will use TensorFlow to train the model - a widely used machine learning library +# created by Google. 
TensorFlow is a very low-level library, however, so we will the Keras +# interface to talk to TensorFlow. We will also use TensorFlow Lite to perform quantization on +# our model, as TensorFlow by itself does not support this. +# +# Once we have our generated model, we will use TVM to compile and test it. To avoid having to +# build from source, we'll install ``tlcpack`` - a community build of TVM. Lastly, we'll also +# install ``imagemagick`` and ``curl`` to preprocess data: +# +# .. code-block:: bash +# +# %%bash +# pip install -q tensorflow tflite +# pip install -q tlcpack-nightly -f https://tlcpack.ai/wheels +# apt-get -qq install imagemagick curl +# +# # Install Arduino CLI and library for Nano 33 BLE +# curl -fsSL https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh +# /content/bin/arduino-cli core update-index +# /content/bin/arduino-cli core install arduino:mbed_nano +# +# Using the GPU +# ^^^^^^^^^^^^^ +# +# This tutorial demonstrates training a neural network, which is requires a lot of computing power +# and will go much faster if you have a GPU. If you are viewing this tutorial on Google Colab, you +# can enable a GPU by going to **Runtime->Change runtime type** and selecting "GPU" as our hardware +# accelerator. If you are running locally, you can `follow TensorFlow's guide `_ instead. +# +# We can test our GPU installation with the following code: + +import tensorflow as tf + +if not tf.test.gpu_device_name(): + print("No GPU was detected!") + print("Model training will take much longer (~30 minutes instead of ~5)") +else: + print("GPU detected - you're good to go.") + +###################################################################### +# Choosing Our Work Dir +# ^^^^^^^^^^^^^^^^^^^^^ +# We need to pick a directory where our image datasets, trained model, and eventual Arduino sketch +# will all live. If running on Google Colab, we'll save everything in ``/root`` (aka ``~``) but you'll +# probably want to store it elsewhere if running locally. Note that this variable only affects Python +# scripts - you'll have to adjust the Bash commands too. + +import os + +FOLDER = "/root" +# sphinx_gallery_start_ignore +import tempfile + +FOLDER = tempfile.mkdtemp() +# sphinx_gallery_end_ignore + +###################################################################### +# Downloading the Data +# -------------------- +# Convolutional neural networks usually learn by looking at many images, along with labels telling +# the network what those images are. To get these images, we'll need a publicly available dataset +# with thousands of images of all sorts of objects and labels of what's in each image. We'll also +# need a bunch of images that **aren't** of cars, as we're trying to distinguish these two classes. +# +# In this tutorial, we'll create a model to detect if an image contains a **car**, but you can use +# whatever category you like! Just change the source URL below to one containing images of another +# type of object. +# +# To get our car images, we'll be downloading the `Stanford Cars dataset `_, +# which contains 16,185 full color images of cars. We'll also need images of random things that +# aren't cars, so we'll use the `COCO 2017 `_ validation set (it's +# smaller, and thus faster to download than the full training set. Training on the full data set +# would yield better results). 
+# Note that there are some cars in the COCO 2017 data set, but it's
+# a small enough fraction not to matter - just keep in mind that this will drive down our perceived
+# accuracy slightly.
+#
+# We could use the TensorFlow dataloader utilities, but we'll instead do it manually to make sure
+# it's easy to change the datasets being used. We'll end up with the following file hierarchy:
+#
+# .. code-block::
+#
+# /root
+# ├── images
+# │   ├── target
+# │   │   ├── 000001.jpg
+# │   │   │ ...
+# │   │   └── 016185.jpg
+# │   ├── target.tgz
+# │   ├── random
+# │   │   ├── 000000000139.jpg
+# │   │   │ ...
+# │   │   └── 000000581781.jpg
+# │   └── random.zip
+#
+# We should also note that the Stanford Cars dataset has 8k images, while the COCO 2017 validation
+# set is 5k images - it is not a 50/50 split! If we wanted to, we could weight these classes differently
+# during training to correct for this, but training will still work if we ignore it. It should
+# take about **2 minutes** to download the Stanford Cars dataset, while COCO 2017 validation will take
+# **1 minute**.
+
+import os
+import shutil
+import urllib.request
+
+# Download datasets
+os.makedirs(f"{FOLDER}/images")
+urllib.request.urlretrieve(
+    "http://ai.stanford.edu/~jkrause/car196/cars_train.tgz", f"{FOLDER}/images/target.tgz"
+)
+urllib.request.urlretrieve(
+    "http://images.cocodataset.org/zips/val2017.zip", f"{FOLDER}/images/random.zip"
+)
+
+# Extract them and rename their folders
+shutil.unpack_archive(f"{FOLDER}/images/target.tgz", f"{FOLDER}/images")
+shutil.unpack_archive(f"{FOLDER}/images/random.zip", f"{FOLDER}/images")
+shutil.move(f"{FOLDER}/images/cars_train", f"{FOLDER}/images/target")
+shutil.move(f"{FOLDER}/images/val2017", f"{FOLDER}/images/random")
+
+######################################################################
+# Loading the Data
+# ----------------
+# Currently, our data is stored on-disk as JPG files of various sizes. To train with it, we'll have
+# to load the images into memory, resize them to be 64x64, and convert them to raw, uncompressed
+# data. Keras's ``image_dataset_from_directory`` will take care of most of this, though it loads
+# images such that each pixel value is a float from 0 to 255.
+#
+# We'll also need to load labels, though Keras will help with this. From our subdirectory structure,
+# it knows the images in ``/target`` are one class, and those in ``/random`` another. Setting
+# ``label_mode='categorical'`` tells Keras to convert these into **categorical labels** - a 2x1 vector
+# that's either ``[1, 0]`` for an object of our target class, or ``[0, 1]`` for anything else.
+# We'll also set ``shuffle=True`` to randomize the order of our examples.
+#
+# We will also **batch** the data - grouping samples into clumps to make our training go faster.
+# We'll set ``batch_size = 32``, which is a decent default.
+#
+# Lastly, in machine learning we generally want our inputs to be small numbers. We'll thus use a
+# ``Rescaling`` layer to change our images such that each pixel is a float between ``0.0`` and ``1.0``,
+# instead of ``0`` to ``255``. We need to be careful not to rescale our categorical labels though, so
+# we'll use a ``lambda`` function.
+
+IMAGE_SIZE = (64, 64, 3)
+unscaled_dataset = tf.keras.utils.image_dataset_from_directory(
+    f"{FOLDER}/images",
+    batch_size=32,
+    shuffle=True,
+    label_mode="categorical",
+    image_size=IMAGE_SIZE[0:2],
+)
+rescale = tf.keras.layers.Rescaling(scale=1.0 / 255)
+full_dataset = unscaled_dataset.map(lambda im, lbl: (rescale(im), lbl))
+
+######################################################################
+# What's Inside Our Dataset?
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Before giving this data set to our neural network, we ought to give it a quick visual inspection.
+# Does the data look properly transformed? Do the labels seem appropriate? And what's our ratio of
+# objects to other stuff? We can display some examples from our datasets using ``matplotlib``:
+
+import matplotlib.pyplot as plt
+
+num_target_class = len(os.listdir(f"{FOLDER}/images/target/"))
+num_random_class = len(os.listdir(f"{FOLDER}/images/random/"))
+print(f"{FOLDER}/images/target contains {num_target_class} images")
+print(f"{FOLDER}/images/random contains {num_random_class} images")
+
+# Show some samples and their labels
+SAMPLES_TO_SHOW = 10
+plt.figure(figsize=(20, 10))
+for i, (image, label) in enumerate(unscaled_dataset.unbatch()):
+    if i >= SAMPLES_TO_SHOW:
+        break
+    ax = plt.subplot(1, SAMPLES_TO_SHOW, i + 1)
+    plt.imshow(image.numpy().astype("uint8"))
+    plt.title(list(label.numpy()))
+    plt.axis("off")
+
+######################################################################
+# Validating our Accuracy
+# ^^^^^^^^^^^^^^^^^^^^^^^
+# While developing our model, we'll often want to check how accurate it is (e.g. to see if it
+# improves during training). How do we do this? We could just train it on *all* of the data, and
+# then ask it to classify that same data. However, our model could cheat by just memorizing all of
+# the samples, which would make it *appear* to have very high accuracy, but perform very badly in
+# reality. In practice, this "memorizing" is called **overfitting**.
+#
+# To prevent this, we will set aside some of the data (we'll use 20%) as a **validation set**. Our
+# model will never be trained on validation data - we'll only use it to check our model's accuracy.
+
+num_batches = len(full_dataset)
+train_dataset = full_dataset.take(int(num_batches * 0.8))
+validation_dataset = full_dataset.skip(len(train_dataset))
+
+######################################################################
+# Defining Our Model
+# ------------------
+# In the past decade, `convolutional neural networks `_ have been widely
+# adopted for image classification tasks. State-of-the-art models like `EfficientNet V2 `_ are able
+# to perform image classification better than even humans! Unfortunately, these models have tens of
+# millions of parameters, and thus won't fit on cheap security camera computers.
+#
+# Our applications generally don't need perfect accuracy - 90% is good enough. We can thus use the
+# older and smaller MobileNet V1 architecture. But this *still* won't be small enough - by default,
+# MobileNet V1 with 224x224 inputs and alpha 1.0 takes ~50 MB to just **store**. To reduce the size
+# of the model, there are three knobs we can turn. First, we can reduce the size of the input images
+# from 224x224 to 96x96 or 64x64, and Keras makes it easy to do this. We can also reduce the **alpha**
+# of the model, from 1.0 to 0.25, which downscales the width of the network (and the number of
+# filters) by a factor of four.
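+#
+# To get a rough feel for how much the **alpha** knob matters, one optional check (a sketch that is
+# not required for the rest of this tutorial, and assumes only that ``tf`` has been imported as above)
+# is to build a couple of randomly initialized MobileNets and compare their parameter counts:
+#
+# .. code-block:: python
+#
+#     for alpha in (1.0, 0.25):
+#         m = tf.keras.applications.MobileNet(
+#             input_shape=(64, 64, 3), alpha=alpha, weights=None
+#         )
+#         print(f"alpha={alpha}: {m.count_params():,} parameters")
+#
+# Note that for MobileNet V1 the parameter count is driven by alpha rather than by the input
+# resolution - shrinking the input mainly reduces the memory needed for intermediate activations
+# at run time.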
+# And if we were really strapped for space, we could reduce the
+# number of **channels** by making our model take grayscale images instead of RGB ones.
+#
+# In this tutorial, we will use an RGB 64x64 input image and alpha 0.25. This is not quite
+# ideal, but it allows the finished model to fit in 192 KB of RAM, while still letting us perform
+# transfer learning using the official TensorFlow source models (if we used alpha <0.25 or a
+# grayscale input, we wouldn't be able to do this).
+#
+# What is Transfer Learning?
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Deep learning has `dominated image classification `_ for a long time,
+# but training neural networks takes a lot of time. When a neural network is trained "from scratch",
+# its parameters start out randomly initialized, forcing it to learn very slowly how to tell images
+# apart.
+#
+# With transfer learning, we instead start with a neural network that's **already** good at a
+# specific task. In this example, that task is classifying images from `the ImageNet database `_. This
+# means the network already has some object detection capabilities, and is likely closer to what you
+# want than a random model would be.
+#
+# This works especially well with image processing neural networks like MobileNet. In practice, it
+# turns out the convolutional layers of the model (i.e. the first 90% of the layers) are used for
+# identifying low-level features like lines and shapes - only the last few fully connected layers
+# are used to determine how those shapes make up the objects the network is trying to detect.
+#
+# We can take advantage of this by starting training with a MobileNet model that was trained on
+# ImageNet, and already knows how to identify those lines and shapes. We can then just remove the
+# last few layers from this pretrained model, and add our own final layers. We'll then train this
+# conglomerate model for a few epochs on our cars vs non-cars dataset, to adjust the first layers
+# and train the last layers from scratch. This process of training an already-partially-trained
+# model is called *fine-tuning*.
+#
+# Source MobileNets for transfer learning have been `pretrained by the TensorFlow folks `_, so we
+# can just download the one closest to what we want (the 128x128 input model with 0.25 depth scale).
+
+os.makedirs(f"{FOLDER}/models")
+WEIGHTS_PATH = f"{FOLDER}/models/mobilenet_2_5_128_tf.h5"
+urllib.request.urlretrieve(
+    "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_2_5_128_tf.h5",
+    WEIGHTS_PATH,
+)
+
+pretrained = tf.keras.applications.MobileNet(
+    input_shape=IMAGE_SIZE, weights=WEIGHTS_PATH, alpha=0.25
+)
+
+######################################################################
+# Modifying Our Network
+# ^^^^^^^^^^^^^^^^^^^^^
+# As mentioned above, our pretrained model is designed to classify the 1,000 ImageNet categories,
+# but we want to convert it to classify cars. Since only the bottom few layers are task-specific,
+# we'll **cut off the last five layers** of our original model. In their place we'll build our own
+# "tail" to the model by performing reshape, dropout, flatten, and softmax operations.
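+#
+# If you are curious exactly which layers are about to be dropped, one optional way to peek at them
+# (purely illustrative - nothing later in the tutorial depends on it) is:
+#
+# .. code-block:: python
+#
+#     for layer in pretrained.layers[-5:]:
+#         print(layer.name, layer.output_shape)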
+
+model = tf.keras.models.Sequential()
+
+model.add(tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE))
+model.add(tf.keras.Model(inputs=pretrained.inputs, outputs=pretrained.layers[-5].output))
+
+model.add(tf.keras.layers.Reshape((-1,)))
+model.add(tf.keras.layers.Dropout(0.1))
+model.add(tf.keras.layers.Flatten())
+model.add(tf.keras.layers.Dense(2, activation="softmax"))
+
+######################################################################
+# Fine Tuning Our Network
+# ^^^^^^^^^^^^^^^^^^^^^^^
+# When training neural networks, we must set a parameter called the **learning rate** that controls
+# how fast our network learns. It must be set carefully - too slow, and our network will take
+# forever to train; too fast, and our network won't be able to learn some fine details. Generally
+# for Adam (the optimizer we're using), ``0.001`` is a pretty good learning rate (and is what's
+# recommended in the `original paper `_). However, in this case
+# ``0.0005`` seems to work a little better.
+#
+# We'll also pass the validation set from earlier to ``model.fit``. This will evaluate how good our
+# model is each time we train it, and let us track how our model is improving. Once training is
+# finished, the model should have a validation accuracy around ``0.98`` (meaning it was right 98% of
+# the time on our validation set).
+
+model.compile(
+    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
+    loss="categorical_crossentropy",
+    metrics=["accuracy"],
+)
+model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2)
+
+######################################################################
+# Quantization
+# ------------
+# We've done a decent job of reducing our model's size so far - changing the input dimension,
+# along with removing the bottom layers, reduced the model to just 219k parameters. However, each of
+# these parameters is a ``float32`` that takes four bytes, so our model will take up almost one MB!
+#
+# Additionally, it might be the case that our hardware doesn't have built-in support for floating
+# point numbers. While most high-memory Arduinos (like the Nano 33 BLE) do have hardware support,
+# some others (like the Arduino Due) do not. On any boards *without* dedicated hardware support,
+# floating point multiplication will be extremely slow.
+#
+# To address both issues, we will **quantize** the model - representing the weights as eight-bit
+# integers. It's more complex than just rounding, though - to get the best performance, TensorFlow
+# tracks how each neuron in our model activates, so we can figure out how to most accurately simulate
+# the neuron's original activations with integer operations.
+#
+# We will help TensorFlow do this by creating a representative dataset - a subset of the original
+# that is used for tracking how those neurons activate. We'll then pass this into a ``TFLiteConverter``
+# (Keras itself does not have quantization support) with an ``Optimize`` flag to tell TFLite to perform
+# the conversion. By default, TFLite keeps the inputs and outputs of our model as floats, so we must
+# explicitly tell it to avoid this behavior.
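+#
+# After running the conversion below, one optional sanity check (again just a sketch, not required
+# by the rest of the tutorial) is to load the result back into the TFLite interpreter and confirm
+# that the model's inputs and outputs really did become ``uint8``:
+#
+# .. code-block:: python
+#
+#     interpreter = tf.lite.Interpreter(model_content=quantized_model)
+#     print(interpreter.get_input_details()[0]["dtype"])   # expect numpy uint8
+#     print(interpreter.get_output_details()[0]["dtype"])  # expect numpy uint8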
+ + +def representative_dataset(): + for image_batch, label_batch in full_dataset.take(10): + yield [image_batch] + + +converter = tf.lite.TFLiteConverter.from_keras_model(model) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +converter.representative_dataset = representative_dataset +converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] +converter.inference_input_type = tf.uint8 +converter.inference_output_type = tf.uint8 + +quantized_model = converter.convert() + +###################################################################### +# Download the Model if Desired +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We've now got a finished model that you can use locally or in other tutorials (try autotuning +# this model or viewing it on `https://netron.app/ `_). But before we do +# those things, we'll have to write it to a file (``quantized.tflite``). If you're running this +# tutorial on Google Colab, you'll have to uncomment the last two lines to download the file +# after writing it. + +QUANTIZED_MODEL_PATH = f"{FOLDER}/models/quantized.tflite" +with open(QUANTIZED_MODEL_PATH, "wb") as f: + f.write(quantized_model) +# from google.colab import files +# files.download(QUANTIZED_MODEL_PATH) + +###################################################################### +# Compiling With TVM For Arduino +# ------------------------------ +# TensorFlow has a built-in framework for deploying to microcontrollers - `TFLite Micro `_. However, +# it's poorly supported by development boards and does not support autotuning. We will use Apache +# TVM instead. +# +# TVM can be used either with its command line interface (``tvmc``) or with its Python interface. The +# Python interface is fully-featured and more stable, so we'll use it here. +# +# TVM is an optimizing compiler, and optimizations to our model are performed in stages via +# **intermediate representations**. The first of these is `Relay `_ a high-level intermediate +# representation emphasizing portability. The conversion from ``.tflite`` to Relay is done without any +# knowledge of our "end goal" - the fact we intend to run this model on an Arduino. +# +# Choosing an Arduino Board +# ^^^^^^^^^^^^^^^^^^^^^^^^^ +# Next, we'll have to decide exactly which Arduino board to use. The Arduino sketch that we +# ultimately generate should be compatible with any board, but knowing which board we are using in +# advance allows TVM to adjust its compilation strategy to get better performance. +# +# There is one catch - we need enough **memory** (flash and RAM) to be able to run our model. We +# won't ever be able to run a complex vision model like a MobileNet on an Arduino Uno - that board +# only has 2 kB of RAM and 32 kB of flash! Our model has ~200,000 parameters, so there is just no +# way it could fit. +# +# For this tutorial, we will use the Nano 33 BLE, which has 1 MB of flash memory and 256 KB of RAM. +# However, any other Arduino with those specs or better should also work. +# +# Generating our project +# ^^^^^^^^^^^^^^^^^^^^^^ +# Next, we'll compile the model to TVM's MLF (model library format) intermediate representation, +# which consists of C/C++ code and is designed for autotuning. To improve performance, we'll tell +# TVM that we're compiling for the ``nrf52840`` microprocessor (the one the Nano 33 BLE uses). We'll +# also tell it to use the C runtime (abbreviated ``crt``) and to use ahead-of-time memory allocation +# (abbreviated ``aot``, which helps reduce the model's memory footprint). 
Lastly, we will disable +# vectorization with ``"tir.disable_vectorize": True``, as C has no native vectorized types. +# +# Once we have set these configuration parameters, we will call ``tvm.relay.build`` to compile our +# Relay model into the MLF intermediate representation. From here, we just need to call +# ``tvm.micro.generate_project`` and pass in the Arduino template project to finish compilation. + +import shutil +import tflite +import tvm + +# Method to load model is different in TFLite 1 vs 2 +try: # TFLite 2.1 and above + tflite_model = tflite.Model.GetRootAsModel(quantized_model, 0) +except AttributeError: # Fall back to TFLite 1.14 method + tflite_model = tflite.Model.Model.GetRootAsModel(quantized_model, 0) + +# Convert to the Relay intermediate representation +mod, params = tvm.relay.frontend.from_tflite(tflite_model) + +# Set configuration flags to improve performance +target = tvm.target.target.micro("nrf52840") +runtime = tvm.relay.backend.Runtime("crt") +executor = tvm.relay.backend.Executor("aot", {"unpacked-api": True}) + +# Convert to the MLF intermediate representation +with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(mod, target, runtime=runtime, executor=executor, params=params) + +# Generate an Arduino project from the MLF intermediate representation +shutil.rmtree(f"{FOLDER}/models/project", ignore_errors=True) +arduino_project = tvm.micro.generate_project( + tvm.micro.get_microtvm_template_projects("arduino"), + mod, + f"{FOLDER}/models/project", + { + "arduino_board": "nano33ble", + "arduino_cli_cmd": "/content/bin/arduino-cli", + "project_type": "example_project", + }, +) + +###################################################################### +# Testing our Arduino Project +# --------------------------- +# Consider the following two 224x224 images from the author's camera roll - one of a car, one not. +# We will test our Arduino project by loading both of these images and executing the compiled model +# on them. +# +# .. image:: https://raw.githubusercontent.com/guberti/web-data/micro-train-tutorial-data/testdata/microTVM/data/model_train_images_combined.png +# :align: center +# :height: 200px +# :width: 600px +# +# Currently, these are 224x224 PNG images we can download from Imgur. Before we can feed in these +# images, we'll need to resize and convert them to raw data, which can be done with ``imagemagick``. +# +# It's also challenging to load raw data onto an Arduino, as only C/CPP files (and similar) are +# compiled. We can work around this by embedding our raw data in a hard-coded C array with the +# built-in utility ``bin2c`` that will output a file like below: +# +# .. code-block:: c +# +# static const unsigned char CAR_IMAGE[] = { +# 0x22,0x23,0x14,0x22, +# ... +# 0x07,0x0e,0x08,0x08 +# }; +# +# We can do both of these things with a few lines of Bash code: +# +# .. 
code-block:: bash
+#
+# %%bash
+# mkdir -p ~/tests
+# curl "https://i.imgur.com/JBbEhxN.png" -o ~/tests/car_224.png
+# convert ~/tests/car_224.png -resize 64 ~/tests/car_64.png
+# stream ~/tests/car_64.png ~/tests/car.raw
+# bin2c -c -st ~/tests/car.raw --name CAR_IMAGE > ~/models/project/car.c
+#
+# curl "https://i.imgur.com/wkh7Dx2.png" -o ~/tests/catan_224.png
+# convert ~/tests/catan_224.png -resize 64 ~/tests/catan_64.png
+# stream ~/tests/catan_64.png ~/tests/catan.raw
+# bin2c -c -st ~/tests/catan.raw --name CATAN_IMAGE > ~/models/project/catan.c
+
+######################################################################
+# Writing our Arduino Script
+# --------------------------
+# We now need a little bit of Arduino code to read the two binary arrays we just generated, run the
+# model on them, and log the output to the serial monitor. This file will replace ``arduino_sketch.ino``
+# as the main file of our sketch. You'll have to copy this code in manually.
+#
+# .. code-block:: c
+#
+# %%writefile /root/models/project.ino
+# #include "src/model.h"
+# #include "car.c"
+# #include "catan.c"
+#
+# void setup() {
+# Serial.begin(9600);
+# TVMInitialize();
+# }
+#
+# void loop() {
+# uint8_t result_data[2];
+# Serial.println("Car results:");
+# TVMExecute(const_cast(CAR_IMAGE), result_data);
+# Serial.print(result_data[0]); Serial.print(", ");
+# Serial.print(result_data[1]); Serial.println();
+#
+# Serial.println("Other object results:");
+# TVMExecute(const_cast(CATAN_IMAGE), result_data);
+# Serial.print(result_data[0]); Serial.print(", ");
+# Serial.print(result_data[1]); Serial.println();
+#
+# delay(1000);
+# }
+#
+# Compiling Our Code
+# ^^^^^^^^^^^^^^^^^^
+# Now that our project has been generated, TVM's job is mostly done! We can still call
+# ``arduino_project.build()`` and ``arduino_project.upload()``, but these just use ``arduino-cli``'s
+# compile and flash commands underneath. We could also begin autotuning our model, but that's a
+# subject for a different tutorial. To finish up, we'll verify no compiler errors are thrown
+# by our project:
+
+shutil.rmtree(f"{FOLDER}/models/project/build", ignore_errors=True)
+# sphinx_gallery_start_ignore
+from unittest.mock import MagicMock
+
+arduino_project = MagicMock()
+# sphinx_gallery_end_ignore
+arduino_project.build()
+print("Compilation succeeded!")
+
+######################################################################
+# Uploading to Our Device
+# -----------------------
+# The very last step is uploading our sketch to an Arduino to make sure our code works properly.
+# Unfortunately, we can't do that from Google Colab, so we'll have to download our sketch. This is
+# simple enough to do - we'll just turn our project into a ``.zip`` archive, and call ``files.download``.
+# If you're running on Google Colab, you'll have to uncomment the last two lines to download the file
+# after writing it.
+
+ZIP_FOLDER = f"{FOLDER}/models/project"
+shutil.make_archive(ZIP_FOLDER, "zip", ZIP_FOLDER)
+# from google.colab import files
+# files.download(f"{FOLDER}/models/project.zip")
+# sphinx_gallery_start_ignore
+# Run a few unit tests to make sure the Python code worked
+
+# Ensure the transfer-learned model was correctly assembled
+assert len(model.layers) == 5
+assert model.count_params() == 219058  # Only 219,058 of these are trainable
+
+assert len(quantized_model) >= 250000  # Quantized model will be 250 KB - 350 KB
+assert len(quantized_model) <= 350000  # Exact value depends on quantization
+
+# Assert .tflite and .zip files were written to disk
+assert os.path.isfile(f"{FOLDER}/models/quantized.tflite")
+assert os.path.isfile(f"{FOLDER}/models/project.zip")
+
+# Assert MLF file was correctly generated
+assert str(mod.executor) == "aot"
+
+# Remove the temporary folder we generated at the beginning
+shutil.rmtree(FOLDER)
+# sphinx_gallery_end_ignore
+
+
+######################################################################
+# From here, we'll need to open it in the Arduino IDE. You'll have to download the IDE as well as
+# the SDK for whichever board you are using. For certain boards like the Sony SPRESENSE, you may
+# have to change settings to control how much memory you want the board to use.
+#
+# Expected Results
+# ^^^^^^^^^^^^^^^^
+# If all works as expected, you should see the following output on a Serial monitor:
+#
+# .. code-block::
+#
+# Car results:
+# 255, 0
+# Other object results:
+# 0, 255
+#
+# The first number represents the model's confidence that the object **is** a car and ranges from
+# 0-255. The second number represents the model's confidence that the object **is not** a car and
+# is also 0-255. These results mean the model is very sure that the first image is a car, and the
+# second image is not (which is correct). Hence, our model is working!
+#
+# Summary
+# -------
+# In this tutorial, we used transfer learning to quickly train an image recognition model to
+# identify cars. We modified its input dimensions and last few layers to make it better at this,
+# and to make it faster and smaller. We then quantized the model and compiled it using TVM to
+# create an Arduino sketch. Lastly, we tested the model using two static images to prove it works
+# as intended.
+#
+# Next Steps
+# ^^^^^^^^^^
+# From here, we could modify the model to read live images from the camera - we have another
+# Arduino tutorial for how to do that `on GitHub `_. Alternatively, we could also
+# `use TVM's autotuning capabilities `_ to dramatically improve the model's performance.
+#
diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py
index 599bbaddceec9..1ffd2d20e7ae9 100755
--- a/tests/scripts/ci.py
+++ b/tests/scripts/ci.py
@@ -260,7 +260,8 @@ def docs(
 "tlcpack-sphinx-addon==0.2.1",
 "synr==0.5.0",
 "image==1.5.33",
- "sphinx-gallery==0.4.0",
+ # Temporary git link until a release is published
+ "git+https://github.com/sphinx-gallery/sphinx-gallery.git@6142f1791151849b5bec4bf3959f75697ba226cd",
 "sphinx-rtd-theme==1.0.0",
 "matplotlib==3.3.4",
 "commonmark==0.9.1",
diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh
index b4b52ed36ccf1..da1a2c9c5636a 100755
--- a/tests/scripts/task_python_docs.sh
+++ b/tests/scripts/task_python_docs.sh
@@ -84,6 +84,8 @@ IGNORED_WARNINGS=(
 'autotvm:Cannot find config for target=llvm -keys=cpu -link-params=0'
 'autotvm:One or more operators have not been tuned. Please tune your model for better performance.
Use DEBUG logging level to see more details.' 'autotvm:Cannot find config for target=cuda -keys=cuda,gpu' + # Warning is thrown during TFLite quantization for micro_train tutorial + 'absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.' ) JOINED_WARNINGS=$(join_by '|' "${IGNORED_WARNINGS[@]}") From fe24fa9840500b9217f5773e65a764a16e998a66 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Jun 2022 01:37:23 -0700 Subject: [PATCH 041/181] [Bugfix][MetaSchedule] Auto-bind when there are no spatial loops (#11570) --- src/meta_schedule/schedule_rule/auto_bind.cc | 38 +++++++++++----- ...t_meta_schedule_schedule_rule_auto_bind.py | 45 ++++++++++++++++++- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 9c16856557e00..61f8e4f6fc54f 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -72,7 +72,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, if (i_multi_child == -1) { i_multi_child = n; } - if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) { + if (i_block_idx != -1 && i_thread_idx != -1) { return; } if (i_block_idx != -1 && i_thread_idx == -1) { @@ -80,16 +80,34 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, throw; } LoopRV loop_rv{nullptr}; - if (i_block_idx == -1 && i_thread_idx != -1) { - int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); + { Array loop_rvs = sch->GetLoops(block_rv); - loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); - sch->Bind(loop_rv, "blockIdx.x"); - return; - } else { // i_block_idx == -1 && i_thread_idx == -1 - Array loop_rvs = sch->GetLoops(block_rv); - int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); - loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + if (i_spatial_loop == -1) { + Array split = sch->Split(loop_rvs[0], {Integer(1), NullOpt}); + ICHECK_EQ(split.size(), 2); + loop_rvs.Set(0, split[1]); + loop_rvs.insert(loop_rvs.begin(), split[0]); + i_spatial_loop = 0; + if (i_block_idx != -1) { + i_block_idx += 1; + } + if (i_thread_idx != -1) { + i_thread_idx += 1; + } + if (i_multi_child != -1) { + i_multi_child += 1; + } + } + if (i_block_idx == -1 && i_thread_idx != -1) { + int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); + Array loop_rvs = sch->GetLoops(block_rv); + loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + sch->Bind(loop_rv, "blockIdx.x"); + return; + } else { // i_block_idx == -1 && i_thread_idx == -1 + int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); + loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + } } int64_t extent = -1; if (const int64_t* e = GetLoopIntExtent(sch->Get(loop_rv).get())) { diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index bd0a24e8b642e..80a72a4e93ab2 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -20,8 +20,8 @@ from tvm.meta_schedule.testing.schedule_rule import auto_bind from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import 
TuneContext -from tvm.target import Target from tvm.script import tir as T +from tvm.target import Target @T.prim_func @@ -34,6 +34,25 @@ def element_wise(var_A: T.handle, var_B: T.handle) -> None: B[vi, vj] = A[vi, vj] + 1.0 +@T.prim_func +def reduction_loop_only( + A: T.Buffer[2, "float32"], + B: T.Buffer[2, "float32"], + C: T.Buffer[(), "float32"], +) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + for i0 in T.serial(2): + with T.block("C"): + k0 = T.axis.reduce(2, i0) + T.reads(A[k0], B[k0]) + T.writes(C[()]) + with T.init(): + C[()] = T.float32(1.0) + C[()] = T.min(C[()], A[k0] / B[k0]) + + def _create_context(mod, target, rule) -> TuneContext: ctx = TuneContext( mod=mod, @@ -71,5 +90,29 @@ def test_cuda_element_wise(): check_trace(spaces, expected) +def test_cuda_reduction_loop_only(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, = sch.get_loops(block=b0)", + "l2, l3 = sch.split(loop=l1, factors=[1, None])", + "l4 = sch.fuse(l2)", + "l5, l6 = sch.split(loop=l4, factors=[None, 1])", + 'sch.bind(loop=l5, thread_axis="blockIdx.x")', + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + ] + ] + target = Target("nvidia/geforce-rtx-3080", host="llvm") + ctx = _create_context( + reduction_loop_only, + target=target, + rule=auto_bind(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + if __name__ == "__main__": test_cuda_element_wise() + test_cuda_reduction_loop_only() From 9d2c9a7f6457fb98156a722625c95bf3383dec42 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Jun 2022 17:48:19 -0700 Subject: [PATCH 042/181] [TIR] Schedule Primitive: Add-Unit-Loop (#11575) In TE, a unit loop could be introduced by fusing an empty list of loops on a stage. This PR adds its counterpart in TIR, while being a bit more explicit with a new schedule primitive which adds a unit loop without impacting any existing functionalities. --- include/tvm/tir/schedule/schedule.h | 12 ++++ python/tvm/tir/schedule/schedule.py | 64 +++++++++++++++-- src/tir/schedule/concrete_schedule.cc | 18 +++++ src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 10 +++ .../schedule/primitive/loop_transformation.cc | 69 +++++++++++++++++++ src/tir/schedule/schedule.cc | 12 ++++ src/tir/schedule/traced_schedule.cc | 22 ++++++ src/tir/schedule/traced_schedule.h | 2 + .../unittest/test_tir_schedule_split_fuse.py | 58 ++++++++++++++++ 10 files changed, 265 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 68900e107d7c9..d3ecd8a1135b8 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -303,6 +303,18 @@ class ScheduleNode : public runtime::Object { * \param ordered_loop_rvs The loops in the new order */ virtual void Reorder(const Array& ordered_loop_rvs) = 0; + /*! + * \brief Create a new unit loop on top of the specific block. + * \param block_rv The block above which the new loop is created + * \return The new loop created + */ + virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0; + /*! + * \brief Create a new unit loop on top of the specific loop. + * \param loop_rv The loop above which the new loop is created + * \return The new loop created + */ + virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0; /******** Schedule: Manipulate ForKind ********/ /*! * \brief Parallelize the input loop. 
It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4179088aa534d..d225280b655f7 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,19 +15,19 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -from typing import Callable, Dict, List, Optional, Union, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer -from ..function import IndexMap +from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc +from ..function import IndexMap from . import _ffi_api +from ._type_checker import type_checked from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod from .trace import Trace -from ._type_checker import type_checked @register_error @@ -685,6 +685,62 @@ def after_reorder(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member + @type_checked + def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV: + """Create a new unit loop on top of the specific block or loop. + + Parameters + ---------- + block_or_loop : Union[LoopRV, BlockRV] + The block above which the new loop is created + + Returns + ------- + new_loop : LoopRV + The new unit loop + + Examples + -------- + + Before add_unit_loop, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_add_unit_loop( + A: T.Buffer[(), "int32"], + B: T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + Create the schedule and do add-unit-loop: + + .. code-block:: python + + sch = tir.Schedule(before_add_unit_loop) + sch.add_unit_loop(sch.get_block("C")) + print(sch.mod["main"].script()) + + After applying add-unit-loop, the IR becomes: + + .. 
code-block:: python + + @T.prim_func + def after_add_unit_loop( + A: T.Buffer[(), "int32"], + B: T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + for u in T.serial(1): + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + """ + return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop) # type: ignore # pylint: disable=no-member + ########## Schedule: Manipulate ForKind ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 590a0f0025954..051bd42506252 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -453,6 +453,24 @@ void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { this->state_->DebugVerify(); } +LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { + LoopRV result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + +LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { + LoopRV result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::AddUnitLoop(state_, GetSRef(loop_rv))); + TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + /******** Schedule: Manipulate ForKind ********/ void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 70c0265611c31..11d68694a1fec 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -99,6 +99,8 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; void Reorder(const Array& ordered_loop_rvs) override; + LoopRV AddUnitLoop(const BlockRV& block_rv) override; + LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) override; void Vectorize(const LoopRV& loop_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index f4dba69c6b156..af0f417e4cf50 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -186,6 +186,16 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); */ TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); +/*! + * \brief Create a new unit loop on top of the specific block or loop. + * \param sref The block/loop above which the new thread_binding loop is created + * \param extent The extent of the new thread_binding loop + * \param thread_axis The thread axis of the new thread_binding loop + * \param attrs Extra loop attributes + * \return The new thread_binding loop + */ +TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref); + /******** Schedule: Manipulate ForKind ********/ /*! * \brief Parallelize the input loop. 
It requires: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 5315b139f0f6f..66e29518ca5e1 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -698,6 +698,43 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { self->Replace(GetRef(top), new_loop, {}); } +StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { + if (sref->stmt->IsInstance()) { + For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef(sref->stmt)); + self->Replace(sref, new_loop, {}); + return self->stmt2ref.at(new_loop.get()); + } + class NewLoopCreator : public StmtMutator { + public: + explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {} + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block.get() == src_block_) { + new_loop_ = + For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef(realize)); + return new_loop_; + } + return StmtMutator::VisitStmt_(realize); + } + + const StmtNode* src_block_; + For new_loop_{nullptr}; + }; + + CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block"; + StmtSRef parent_sref = GetRef(sref->parent); + NewLoopCreator creator(sref->stmt); + Stmt new_stmt = creator(GetRef(parent_sref->stmt)); + if (new_stmt->IsInstance()) { + self->Replace(parent_sref, std::move(new_stmt), {}); + } else { + Block old_parent_block = GetRef(parent_sref->StmtAs()); + Block new_parent_block = Downcast(new_stmt); + self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}}); + } + return self->stmt2ref.at(creator.new_loop_.get()); +} + /******** InstructionKind Registration ********/ struct SplitTraits : public UnpackedInstTraits { @@ -800,9 +837,41 @@ struct ReorderTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct AddUnitLoopTraits : public UnpackedInstTraits { + static constexpr const char* kName = "AddUnitLoop"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) { + if (const auto* block = rv.as()) { + return sch->AddUnitLoop(GetRef(block)); + } else if (const auto* loop = rv.as()) { + return sch->AddUnitLoop(GetRef(loop)); + } else { + LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block"; + throw; + } + } + + static String UnpackedAsPython(Array outputs, String rv) { + PythonAPICall py("add_unit_loop"); + py.Input("block_or_loop", rv); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits); +TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 3880d0b19eeb8..372d94a15025b 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -153,6 +153,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&Sche TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") .set_body_method(&ScheduleNode::Reorder); 
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") + .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { + if (const auto* loop_rv = rv.as()) { + return self->AddUnitLoop(GetRef(loop_rv)); + } else if (const auto* block_rv = rv.as()) { + return self->AddUnitLoop(GetRef(block_rv)); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + } + }); /******** (FFI) Manipulate ForKind ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel") .set_body_method(&ScheduleNode::Parallel); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d2f627edfd11d..95a10e26ac2f8 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -198,6 +198,28 @@ void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { /*outputs=*/{})); } +LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) { + LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{result})); + return result; +} + +LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { + LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{result})); + return result; +} + /******** Schedule: Manipulate ForKind ********/ void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ba4a4b99cbb2d..25bf3d4871ae7 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -63,6 +63,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; void Reorder(const Array& ordered_loop_rvs) final; + LoopRV AddUnitLoop(const BlockRV& block_rv) final; + LoopRV AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) final; void Vectorize(const LoopRV& loop_rv) final; diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 16eef57c4748d..d70748bc8a03d 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -524,5 +524,63 @@ def test_fuse_not_affine(): verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine) +def test_add_unit_loop_above_block(): + @T.prim_func + def zero_dim( + A: T.Buffer[(), "int32"], + B: T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + @T.prim_func + def zero_dim_added( + A: T.Buffer[(), "int32"], + B: T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + for u in range(1): + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + sch = tir.Schedule(zero_dim, debug_mask="all") + block = sch.get_block("C") + sch.add_unit_loop(block) + tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"]) + + +def test_add_unit_loop_above_loop(): + @T.prim_func + def zero_dim( + A: T.Buffer[(), "int32"], + B: 
T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + for u in range(1): + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + @T.prim_func + def zero_dim_added( + A: T.Buffer[(), "int32"], + B: T.Buffer[(), "int32"], + C: T.Buffer[(), "int32"], + ) -> None: + for u1, u2 in T.grid(1, 1): + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + sch = tir.Schedule(zero_dim, debug_mask="all") + block = sch.get_block("C") + (loop,) = sch.get_loops(block) + sch.add_unit_loop(loop) + tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"]) + + if __name__ == "__main__": tvm.testing.main() From ba60788118e7c65c26cb6cf1097a012dd7b647f2 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Jun 2022 21:42:43 -0700 Subject: [PATCH 043/181] [MetaSchedule] Use Add-Unit-Loop in Auto-Bind (#11581) Following #11575, this PR allows CUDA thread binding for TIR programs like ```python @T.prim_func def zero_dim_add( A: T.Buffer[(), "float32"], B: T.Buffer[(), "float32"], C: T.Buffer[(), "float32"], ) -> None: with T.block("C"): vi = T.axis.spatial(1, 0) C[()] = A[()] + B[()] ``` where there is no loop available to be bound to threadIdx/blockIdx. --- src/meta_schedule/schedule_rule/auto_bind.cc | 18 ++++--- ...t_meta_schedule_schedule_rule_auto_bind.py | 47 +++++++++++++++---- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 61f8e4f6fc54f..2bc90f3c2e5cf 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -30,11 +30,12 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, int64_t max_threadblocks, int64_t max_threads_per_block, std::function get_factor) { using namespace tvm::tir; - Array loops = tir::GetLoops(sch->GetSRef(block_rv)); - int n = loops.size(); - if (n == 0) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (block_sref->parent == nullptr) { return; } + Array loops = tir::GetLoops(block_sref); + int n = loops.size(); int i_block_idx = -1; int i_thread_idx = -1; int i_multi_child = -1; @@ -83,10 +84,13 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, { Array loop_rvs = sch->GetLoops(block_rv); if (i_spatial_loop == -1) { - Array split = sch->Split(loop_rvs[0], {Integer(1), NullOpt}); - ICHECK_EQ(split.size(), 2); - loop_rvs.Set(0, split[1]); - loop_rvs.insert(loop_rvs.begin(), split[0]); + LoopRV spatial_loop_rv{nullptr}; + if (loop_rvs.empty()) { + spatial_loop_rv = sch->AddUnitLoop(block_rv); + } else { + spatial_loop_rv = sch->AddUnitLoop(loop_rvs[0]); + } + loop_rvs.insert(loop_rvs.begin(), spatial_loop_rv); i_spatial_loop = 0; if (i_block_idx != -1) { i_block_idx += 1; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index 80a72a4e93ab2..8b36ec2f462da 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -40,9 +40,6 @@ def reduction_loop_only( B: T.Buffer[2, "float32"], C: T.Buffer[(), "float32"], ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body for i0 in T.serial(2): with T.block("C"): k0 = T.axis.reduce(2, i0) @@ -53,6 +50,17 @@ def reduction_loop_only( C[()] = T.min(C[()], A[k0] / B[k0]) +@T.prim_func +def zero_dim_add( + A: 
T.Buffer[(), "float32"], + B: T.Buffer[(), "float32"], + C: T.Buffer[(), "float32"], +) -> None: + with T.block("C"): + vi = T.axis.spatial(1, 0) + C[()] = A[()] + B[()] + + def _create_context(mod, target, rule) -> TuneContext: ctx = TuneContext( mod=mod, @@ -95,11 +103,11 @@ def test_cuda_reduction_loop_only(): [ 'b0 = sch.get_block(name="C", func_name="main")', "l1, = sch.get_loops(block=b0)", - "l2, l3 = sch.split(loop=l1, factors=[1, None])", - "l4 = sch.fuse(l2)", - "l5, l6 = sch.split(loop=l4, factors=[None, 1])", - 'sch.bind(loop=l5, thread_axis="blockIdx.x")', - 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "l2 = sch.add_unit_loop(block_or_loop=l1)", + "l3 = sch.fuse(l2)", + "l4, l5 = sch.split(loop=l3, factors=[None, 1])", + 'sch.bind(loop=l4, thread_axis="blockIdx.x")', + 'sch.bind(loop=l5, thread_axis="threadIdx.x")', ] ] target = Target("nvidia/geforce-rtx-3080", host="llvm") @@ -113,6 +121,29 @@ def test_cuda_reduction_loop_only(): check_trace(spaces, expected) +def test_cuda_zero_dim_add(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1 = sch.add_unit_loop(block_or_loop=b0)", + "l2 = sch.fuse(l1)", + "l3, l4 = sch.split(loop=l2, factors=[None, 1])", + 'sch.bind(loop=l3, thread_axis="blockIdx.x")', + 'sch.bind(loop=l4, thread_axis="threadIdx.x")', + ] + ] + target = Target("nvidia/geforce-rtx-3080", host="llvm") + ctx = _create_context( + zero_dim_add, + target=target, + rule=auto_bind(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + if __name__ == "__main__": test_cuda_element_wise() test_cuda_reduction_loop_only() + test_cuda_zero_dim_add() From c732828d48c872ff358191da2e2087d38278bb81 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Jun 2022 11:17:32 -0700 Subject: [PATCH 044/181] [TIR] Prevent loop binding over-simplification (#11578) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit @vinx13 @jinhongyii and I observe a recent regression on TVM mainline: over-simplification in `Schedule.split` leads to information loss that negatively impacts search space generation. **Impact.** This affects common operators like `softmax` and even simpler reductions. **Example.** Consider splitting a simple reduction loop: ```python @T.prim_func def main( A: T.Buffer[2, "float32"], B: T.Buffer[2, "float32"], C: T.Buffer[(), "float32"], ) -> None: for i in T.serial(2): # <= split `i` into `i_0` and `i_1`, where `i_0` is a trivial loop with T.block("C"): k = T.axis.reduce(2, i) with T.init(): C[()] = T.float32(1) C[()] = T.min(C[()], A[k] / B[k]) ``` Splitting loop `i` by factors `[1, 2]`, we get: ```python @T.prim_func def main( A: T.Buffer[2, "float32"], B: T.Buffer[2, "float32"], C: T.Buffer[(), "float32"], ) -> None: for i_0, i_1 in T.grid(1, 2): with T.block("C"): k = T.axis.reduce(2, i_1) # <= i_0 is not part of the binding, # so the system cannot tell if i_0 is a reduction loop with T.init(): C[()] = T.float32(1) C[()] = T.min(C[()], A[k] / B[k]) ``` In this case, loop `i_0` will be considered as a spatial loop, even it’s the outcome of splitting a reduction loop. However, if we change the factors from `[1, 2]` to `[2, 1]`, loop `i_0` becomes a reduction loop. This means the loop iteration property depends on the loop extent. **Why is it problematic**? MetaSchedule has an assumption: extremely seldomly, a loop extent would impact the iteration property of the loop itself, i.e. 
no matter the extent is 1 or 2 or anything, the fact that the loop is a reduction loop should rarely change. As an example, `Auto-Bind` finds the outer `k` spatial loops, which are fused together and bound to thread axis. In the trace, the number (`k`) of the outer loops has to be a constant. However, if Auto-Bind thinks there are `k=3` outer loops to fuse during search space generation, where the last loop happens to be a reduction loop with extent 1, as shown below: ```python for spatial_loop_0 in range(...): for spatial_loop_1 in range(...): for reduction_loop in range(1): # <= Auto-Bind mistakes this loop as spatial, because extent==1 ``` During evolutionary search, the extent of reduction_loop will change and become larger than 1. In this case, the binding strategy will consistently fail because it considers fusing `k=3` loops - which means the entire search strategy will fail with almost no valid candidates. Thanks @MasterJH5574 for figuring out the root cause of the issue, and @jinhongyii for valuable pointers to the right fix! --- include/tvm/arith/iter_affine_map.h | 5 +++-- src/arith/iter_affine_map.cc | 6 ++++-- .../schedule/primitive/loop_transformation.cc | 5 +++-- ...edule_postproc_rewrite_cooperative_fetch.py | 2 +- ...st_meta_schedule_schedule_rule_auto_bind.py | 1 - .../unittest/test_tir_schedule_reorder.py | 2 +- .../unittest/test_tir_schedule_split_fuse.py | 12 ++++++------ .../unittest/test_tir_schedule_transform.py | 18 ++++++------------ 8 files changed, 24 insertions(+), 27 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 2c0e5e92997af..6b98d84fdf17e 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -349,11 +349,12 @@ IterMapResult DetectIterMap(const Array& indices, const Map IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, IterMapLevel check_level); + const PrimExpr& input_pred, IterMapLevel check_level, + bool simplify_trivial_iterators = true); /*! * \brief Apply the inverse of the affine transformation to the outputs. 
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index cce826fedca64..ace7b7f84441f 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1720,10 +1720,12 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, IterMapLevel check_level) { + const PrimExpr& input_pred, IterMapLevel check_level, + bool simplify_trivial_iterators) { if (!IterRangeSanityCheck(input_iters)) return indices; Analyzer analyzer; - auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer); + auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer, + /*simplify_trivial_iterators=*/simplify_trivial_iterators); Array rewrite = res->indices; if (rewrite.empty()) { diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 66e29518ca5e1..e374d1f3c5e77 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -115,7 +115,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, - /*check_level=*/arith::IterMapLevel::Surjective); + /*check_level=*/arith::IterMapLevel::Surjective, + /*simplify_trivial_iterators=*/false); if (v.same_as(op->iter_values)) { return GetRef(op); } else { @@ -397,7 +398,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); - if (!is_one(factor)) substitute_value = substitute_value * factor + var; + substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.emplace_back(std::move(var)); } diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index e4dff51cf9d4f..aa1d219d1c65a 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -86,7 +86,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: with T.block("C"): i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) - k = T.axis.reduce(512, i2_1 * 32 + i2_2) + k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2) T.reads([A_shared[i, k], B_shared[k, j]]) T.writes([C_local[i, j]]) with T.init(): diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index 8b36ec2f462da..aa7cb09265e9c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.schedule_rule import auto_bind from tvm.meta_schedule.testing.space_generation import check_trace diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index c5663a5f2ebd2..4351fe5b6361d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -281,7 +281,7 @@ def cascade_pool_ops_tile_reordered( ) for h_i, w, kh, kw in T.grid(4, 108, 3, 3): with T.block("pool_1"): - ax0 = T.axis.spatial(1, 0) + ax0 = T.axis.spatial(1, n) ax1 = T.axis.spatial(16, c) ax2 = T.axis.spatial(108, h_o * 4 + h_i) ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw]) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index d70748bc8a03d..c9e6eec029329 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -178,7 +178,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): with T.block("B"): - vi = T.axis.S(128, i1 * 64 + i3) + vi = T.axis.S(128, (i1 + i2) * 64 + i3) vj = T.axis.S(128, j1 * 32 + j2) vk = T.axis.S(128, k1 * 8 + k2) T.reads([A[vi, vj, vk]]) @@ -192,9 +192,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): with T.block("B"): - vi = T.axis.S(128, i1 * 64 + i3) - vj = T.axis.S(128, j1 * 64 + j3) - vk = T.axis.S(128, k1 * 64 + k3) + vi = T.axis.S(128, (i1 + i2) * 64 + i3) + vj = T.axis.S(128, (j1 + j2) * 64 + j3) + vk = T.axis.S(128, (k1 + k2) * 64 + k3) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -206,10 +206,10 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): with T.block("B"): - T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128) vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) - vj = T.axis.S(128, j1) + vj = T.axis.S(128, j0 * 129 + j1) vk = T.axis.S(128, k0 * 43 + k1) + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py index 6dfd4315ec904..e812587e66761 100644 --- a/tests/python/unittest/test_tir_schedule_transform.py +++ b/tests/python/unittest/test_tir_schedule_transform.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. 
import tvm -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN - -from tvm.tir import Schedule from tvm.script import tir as T +from tvm.tir import Schedule from tvm.tir.schedule.transform import tile_with_tensor_intrin +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN @tvm.script.ir_module @@ -128,11 +127,10 @@ def main( 1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4 ): with T.block("conv2d_NCHWc_int8"): - n = T.axis.spatial(1, 0) - oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1]) - kh = T.axis.reduce(1, 0) - kw = T.axis.reduce(1, 0) - ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1]) + n, oc_chunk, oh, ow = T.axis.remap("SSSS", [i0, i1, i2, i3]) + oc_block = T.axis.spatial(16, i4_0 * 16 + i4_1) + kh, kw, ic_outer, ic_f_inner = T.axis.remap("RRRR", [i5, i6, i7, i8]) + ic_s_inner = T.axis.reduce(4, i9_0 * 4 + i9_1) T.reads( placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], @@ -165,14 +163,10 @@ def test_tile_with_tensor_intrin_dense_vnni(): def test_tile_with_tensor_intrin_conv2d_nchwc_vnni(): s = Schedule(Conv2dNCHWcVNNIModule) block = s.get_block("conv2d_NCHWc_int8") - tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) - tiled_loops = s.get_loops(block) - assert len(tiled_loops) == 12 assert s.get(tiled_loop) == s.get(tiled_loops[-2]) - tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled) From 06c443e9959452c6da3a911fe0c11e08c5554477 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Jun 2022 16:59:17 -0700 Subject: [PATCH 045/181] [Bugfix][TIR] compute-at/fuse/split dtype mismatch (#11582) The schedule primitives, including compute-at, fuse and split usually generate loop variables with `dtype=int32` as default. However, in some models, there are usecases where int64 are part of tensor shapes, which leads to unexpected behavior in scheduling. This PR brings the fix to existing known issues. 
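As a rough illustration, here is a minimal sketch of the kind of schedule that used to trigger the mismatch; it mirrors the new `test_split_int64_extent_with_mixed_factors` unit test added below, so the shapes and factors are only illustrative:

```python
# Minimal sketch: splitting a loop whose extent is int64 with factors that mix
# int64 and int32. Before this fix, the split produced loop variables with
# inconsistent dtypes, which broke later scheduling steps.
from tvm import te, tir

m = te.const(384, "int64")
A = te.placeholder((m,), name="A", dtype="float32")
B = te.compute((m,), lambda i: A[i] + 1, name="B")

sch = tir.Schedule(te.create_prim_func([A, B]), debug_mask="all")
(i,) = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[te.const(1, "int64"), te.const(512, "int32")])
```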
--- src/tir/schedule/primitive/compute_at.cc | 5 +- .../schedule/primitive/loop_transformation.cc | 21 +++++-- .../unittest/test_tir_schedule_compute_at.py | 24 ++++++-- .../unittest/test_tir_schedule_split_fuse.py | 61 ++++++++++++++++++- 4 files changed, 97 insertions(+), 14 deletions(-) diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 7f1d74ac20214..7b0d749f03dcf 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -194,7 +194,7 @@ struct BlockVarDomainInfo { } return; } - // simplify intsets + // simplify intset dom = to_simplified(dom); bound = to_simplified(bound); // if can proof the dom is within bound, remove bound @@ -242,7 +242,8 @@ class ScopeReconstructor : private StmtMutator { for (int i = 0; i < n_iters; ++i) { Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom); if (preserve_unit_loops || !is_one(iter_dom->extent)) { - Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); + int bits = std::max(iter_dom->min.dtype().bits(), iter_dom->extent.dtype().bits()); + Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(bits)); loop_vars.push_back(var); loop_extents.push_back(analyzer->Simplify(iter_dom->extent)); iter_values.push_back(iter_dom->min + var); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index e374d1f3c5e77..bb505bca33763 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -54,7 +54,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); if (Optional ret = vmap_(var)) { - return ret.value(); + return tvm::cast(var.dtype(), ret.value()); } else { return std::move(var); } @@ -391,15 +391,24 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, arith::Analyzer analyzer; CheckLoopStartsWithZero(self, loop_sref, &analyzer); + // Find the most common dtype + DataType dtype; + { + int bits = loop->loop_var.dtype().bits(); + for (const PrimExpr& factor : factors) { + bits = std::max(bits, factor.dtype().bits()); + } + dtype = DataType::Int(bits); + } int n = factors.size(); - PrimExpr substitute_value = 0; + PrimExpr substitute_value = make_const(dtype, 0); std::vector new_loop_vars; new_loop_vars.reserve(n); for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; - Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)).copy_with_dtype(dtype); substitute_value = substitute_value * factor + var; - analyzer.Bind(var, Range::FromMinExtent(0, factor)); + analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); new_loop_vars.emplace_back(std::move(var)); } Map opaque_block_reuse; @@ -481,11 +490,13 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { // Step 2. 
Create fused loop var and replace the original loop vars std::string suffix; int n = loops.size(); + int bits = loops[0]->loop_var.dtype().bits(); for (int i = 1; i < n; i++) { suffix += "_" + loops[i]->loop_var->name_hint; + bits = std::max(bits, loops[i]->loop_var.dtype().bits()); } suffix += "_fused"; - Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits)); Array substitute_value; substitute_value.resize(loops.size()); PrimExpr lower = 1; diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index f477367adfad3..3772d9a4e0fec 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -15,13 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys - import pytest - import tvm import tvm.testing -from tvm import tir +from tvm import te, tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -1335,5 +1332,24 @@ def test_fail_all_producers_under_loop(): sch.reverse_compute_at(block, loop) +def test_compute_at_int64_loop(): + def _create_prim_func(): + n = te.var("n", dtype="int64") + m = te.var("m", dtype="int64") + A = te.placeholder((n, m), name="A", dtype="float32") + B = te.placeholder((n, m), name="B", dtype="float32") + C = te.compute((n, m), lambda i, j: A[i, j] + B[i, j], name="C") + D = te.compute((n, m), lambda i, j: C[i, j] + 1.0, name="D") + return te.create_prim_func([A, B, D]) + + mod = _create_prim_func() + sch = tir.Schedule(mod, debug_mask="all") + block_c = sch.get_block("C") + block_d = sch.get_block("D") + i, _ = sch.get_loops(block_d) + sch.compute_at(block_c, i) + verify_trace_roundtrip(sch=sch, mod=mod) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index c9e6eec029329..0bfac4e425b95 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-function-docstring,missing-module-docstring -import sys - import pytest import tvm import tvm.testing -from tvm import tir +from tvm import te, tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -582,5 +580,62 @@ def zero_dim_added( tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"]) +@pytest.mark.skip("Pending fix in affine analysis") +def test_fuse_int64(): + def _create_prim_func(): + n = te.const(16, "int32") + m = te.const(32, "int64") + A = te.placeholder((n, m), name="A", dtype="int32") + B = te.compute((n, m), lambda i, j: A[i, j] + 1, name="B") + return te.create_prim_func([A, B]) + + mod = _create_prim_func() + sch = tir.Schedule(mod, debug_mask="all") + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + verify_trace_roundtrip(sch=sch, mod=mod) + + +def test_split_int64_extent_with_mixed_factors(): + def _create_prim_func(): + m = te.const(384, "int64") + A = te.placeholder((m,), name="A", dtype="float32") + B = te.compute((m,), lambda i: A[i] + 1, name="B") + return te.create_prim_func([A, B]) + + mod = _create_prim_func() + sch = tir.Schedule(mod, debug_mask="all") + (i,) = sch.get_loops(sch.get_block("B")) + sch.split( + i, + factors=[ + te.const(1, "int64"), + te.const(512, "int32"), + ], + ) + + +def test_split_int64_extent_with_int32_factors(): + def _create_prim_func(): + m = te.const(12, "int64") + A = te.placeholder((m,), name="A", dtype="float32") + B = te.compute((m,), lambda i: A[i] + 1, name="B") + return te.create_prim_func([A, B]) + + mod = _create_prim_func() + sch = tir.Schedule(mod, debug_mask="all") + (i,) = sch.get_loops(sch.get_block("B")) + sch.split( + i, + factors=[ + te.const(1, "int32"), + te.const(1, "int32"), + te.const(3, "int32"), + te.const(1, "int32"), + te.const(4, "int32"), + ], + ) + + if __name__ == "__main__": tvm.testing.main() From 8a568bc823fa7c8c3d37ff15deb4a8faef6d0bbb Mon Sep 17 00:00:00 2001 From: "Kathryn (Jinqi) Chen" <65606304+Kathryn-cat@users.noreply.github.com> Date: Sun, 5 Jun 2022 19:44:52 -0700 Subject: [PATCH 046/181] [MetaSchedule] exposed method: TuneContextNodeInitialize (#11576) I exposed the initialize() method for TuneContextNode on the C++ side and added a corresponding method to TuneContext class on the Python side, so that we do not need to call initialize_with_tune_context for every scheduling rule. 
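A hedged sketch of the intended usage (the workload `mod` and the exact constructor arguments are illustrative, not a full recipe):

```python
# Sketch only: initialize a TuneContext and all of its registered components
# with one call, instead of invoking initialize_with_tune_context by hand.
from tvm import meta_schedule as ms
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.target import Target

context = ms.TuneContext(
    mod=mod,                      # an IRModule to tune, assumed to be defined elsewhere
    target=Target("llvm"),
    space_generator=PostOrderApply(),
    task_name="example",
)
context.initialize()              # the newly exposed method
```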
--- python/tvm/meta_schedule/tune_context.py | 5 +++++ src/meta_schedule/tune_context.cc | 2 ++ 2 files changed, 7 insertions(+) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index ef2e4bcd8e6d9..19ab0a40cf617 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -129,3 +129,8 @@ def __init__( rand_state, num_threads, ) + + def initialize(self): + """Initialize the tuning context""" + + _ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 382dd961dee0e..3607e3050803e 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -89,6 +89,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") }); TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") + .set_body_method(&TuneContextNode::Initialize); } // namespace meta_schedule } // namespace tvm From 8038987411471bbdd03edba75271a1c00d571f23 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Jun 2022 19:45:16 -0700 Subject: [PATCH 047/181] [MetaSchedule] Fix Summary Format for Invalid Runs (#11584) Previously for invalid tasks, MetaSchedule prints a huge number in latency which is aesthetically unacceptable. For example, ``` 69 | fused_cast_add_cast_3 | 16777216 | 2 | 0.0000 | 10000000000000000019156750857346687362159551272651920111528035145993793242039887559612361451081803235328.0000 | 20000000000000000038313501714693374724319102545303840223056070291987586484079775119224722902163606470656.0000 | 64 | ``` This PR fixes this behavior and turns the huge number into "N/A". --- .../task_scheduler/gradient_based.cc | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index a95dbba6c3e14..f8cc9d5514941 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -79,10 +79,14 @@ class GradientBasedNode final : public TaskSchedulerNode { << /*name=*/record.task->task_name.value() // << /*flops=*/static_cast(record.flop) // << /*weight=*/static_cast(record.weight); - if (trials == 0) { + double latency = 1e9; + if (trials > 0) { + latency = record.best_time_cost_history.back(); + } + if (latency >= 1e9) { row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A"; } else { - double latency = record.best_time_cost_history.back() * 1000.0; + latency *= 1000.0; double speed = record.flop / latency / 1000.0; double weighted_latency = latency * record.weight; row << /*speed=*/speed << /*latency=*/latency << /*weighted_latency=*/weighted_latency; @@ -139,10 +143,15 @@ class GradientBasedNode final : public TaskSchedulerNode { int n = record.best_time_cost_history.size(); ICHECK_GE(n, 1); double best = record.best_time_cost_history[n - 1]; - double g1 = (n >= 1 + w) ? (record.best_time_cost_history[n - 1 - w] - best) / w : 0.0; - double g2 = best / n; - double g = alpha * g1 + (1 - alpha) * g2; - grad.push_back(g * record.weight); + if (best < 1e9) { + double g1 = (n >= 1 + w) ? 
(record.best_time_cost_history[n - 1 - w] - best) / w : 0.0; + double g2 = best / n; + double g = alpha * g1 + (1 - alpha) * g2; + grad.push_back(g * record.weight); + } else { + // If the best time cost is unavailable, it means some task is not valid. Skip it. + grad.push_back(-1e9); + } } auto max_grad = std::max_element(grad.begin(), grad.end()); auto min_grad = std::min_element(grad.begin(), grad.end()); From 283542f68a8759eebca97626b983909f55c64699 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Sun, 5 Jun 2022 20:00:19 -0700 Subject: [PATCH 048/181] [CI][DOC] Fix incorrect commands in docs/readme.md (#11583) Fix incorrect commands in docs/readme.md --- docs/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/README.md b/docs/README.md index 520fea60ca28a..0ccb3cd3b954d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -79,14 +79,14 @@ the path that matches the regular expression pattern. For example, to only build tutorials under `/vta/tutorials`, run ```bash -python tests/scripts/ci.py docs --tutorials=/vta/tutorials +python tests/scripts/ci.py docs --tutorial-pattern=/vta/tutorials ``` To only build one specific file, do ```bash # The slash \ is used to get . in regular expression -python tests/scripts/ci.py docs --tutorials=file_name\.py +python tests/scripts/ci.py docs --tutorial-pattern=file_name\.py ``` ## Helper Scripts @@ -95,14 +95,14 @@ You can run the following script to reproduce the CI sphinx pre-check stage. This script skips the tutorial executions and is useful to quickly check the content. ```bash -python tests/scripts/ci.py docs --precheck +tests/scripts/task_python_docs.sh ``` The following script runs the full build which includes tutorial executions. You will need a GPU CI environment. 
```bash -python tests/scripts/ci.py --precheck --full +python tests/scripts/ci.py docs --full ``` ## Define the Order of Tutorials From bf4b8f5c766be8320df8d792a8c063b7b42c69f5 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Sun, 5 Jun 2022 22:11:45 -0500 Subject: [PATCH 049/181] split test_forward_math_api function (#11537) --- .../frontend/paddlepaddle/test_forward.py | 237 ++++++++++++++---- 1 file changed, 193 insertions(+), 44 deletions(-) diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 56ec3a4e5469e..8b696404e2b0c 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -1358,7 +1358,10 @@ def slice4(inputs): @tvm.testing.uses_gpu -def test_forward_math_api(): +def run_math_api(func): + api_name = func.__name__.split("_")[-1] + print("func_name:", api_name) + class MathAPI(nn.Layer): def __init__(self, api_name): super(MathAPI, self).__init__() @@ -1371,52 +1374,198 @@ def __init__(self, api_name): def forward(self, inputs): return self.func(inputs) - api_list = [ - "abs", - "acos", - "asin", - "atan", - "ceil", - "cos", - "cosh", - "elu", - "erf", - "exp", - "floor", - "hardshrink", - "hardtanh", - "log_sigmoid", - "log_softmax", - "log", - "log2", - "log10", - "log1p", - "reciprocal", - "relu", - "relu6", - "round", - "rsqrt", - "selu", - "sigmoid", - "sign", - "sin", - "sinh", - "softplus", - "softsign", - "sqrt", - "square", - "swish", - "tan", - "tanh", - ] input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]] for input_shape in input_shapes: input_data = paddle.rand(input_shape, dtype="float32") - for api_name in api_list: - if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]: - # avoid illegal input, all elements should be positive - input_data = paddle.uniform(input_shape, min=0.01, max=0.99) - verify_model(MathAPI(api_name), input_data=input_data) + if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]: + # avoid illegal input, all elements should be positive + input_data = paddle.uniform(input_shape, min=0.01, max=0.99) + verify_model(MathAPI(api_name), input_data=input_data) + + +@run_math_api +def test_forward_abs(): + pass + + +@run_math_api +def test_forward_acos(): + pass + + +@run_math_api +def test_forward_abs(): + pass + + +@run_math_api +def test_forward_atan(): + pass + + +@run_math_api +def test_forward_ceil(): + pass + + +@run_math_api +def test_forward_cos(): + pass + + +@run_math_api +def test_forward_cosh(): + pass + + +@run_math_api +def test_forward_elu(): + pass + + +@run_math_api +def test_forward_erf(): + pass + + +@run_math_api +def test_forward_exp(): + pass + + +@run_math_api +def test_forward_floor(): + pass + + +@run_math_api +def test_forward_hardshrink(): + pass + + +@run_math_api +def test_forward_hardtanh(): + pass + + +@run_math_api +def test_forward_log_sigmoid(): + pass + + +@run_math_api +def test_forward_log_softmax(): + pass + + +@run_math_api +def test_forward_log(): + pass + + +@run_math_api +def test_forward_log2(): + pass + + +@run_math_api +def test_forward_log10(): + pass + + +@run_math_api +def test_forward_log1p(): + pass + + +@run_math_api +def test_forward_reciprocal(): + pass + + +@run_math_api +def test_forward_relu(): + pass + + +@run_math_api +def test_forward_round(): + pass + + +@run_math_api +def test_forward_rsqrt(): + pass + + +@run_math_api +def test_forward_selu(): + pass + + +@run_math_api +def test_forward_sigmoid(): + 
pass + + +@run_math_api +def test_forward_sign(): + pass + + +@run_math_api +def test_forward_sin(): + pass + + +@run_math_api +def test_forward_softplus(): + pass + + +@run_math_api +def test_forward_sqrt(): + pass + + +@run_math_api +def test_forward_square(): + pass + + +@run_math_api +def test_forward_sin(): + pass + + +@run_math_api +def test_forward_softsign(): + pass + + +@run_math_api +def test_forward_sqrt(): + pass + + +@run_math_api +def test_forward_square(): + pass + + +@run_math_api +def test_forward_swish(): + pass + + +@run_math_api +def test_forward_tan(): + pass + + +@run_math_api +def test_forward_tanh(): + pass @tvm.testing.uses_gpu From b555bf5481d3eb261427850cea286c162aa3d2e3 Mon Sep 17 00:00:00 2001 From: M Date: Mon, 6 Jun 2022 16:54:09 +0800 Subject: [PATCH 050/181] fix bmm quantization realize (#11586) --- src/relay/quantize/realize.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 301dc1a09f396..5766c62eaa433 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -511,13 +511,14 @@ Expr BatchMatmulRealize(const Call& ref_call, const Array& new_args, const Expr ldata = lhs->data; Expr rdata = rhs->data; - DataType dtype = cfg->dtype_input; + DataType dtype_input = cfg->dtype_input; + DataType dtype_weight = cfg->dtype_weight; - if (lhs->dtype != dtype) { - ldata = Cast(ldata, dtype); + if (lhs->dtype != dtype_input) { + ldata = Cast(ldata, dtype_input); } - if (rhs->dtype != dtype) { - rdata = Cast(rdata, dtype); + if (rhs->dtype != dtype_weight) { + rdata = Cast(rdata, dtype_weight); } const auto ref_attrs = ref_call->attrs.as(); From 609d6af17605d657909549e908876f4335206bd6 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 6 Jun 2022 13:29:07 +0100 Subject: [PATCH 051/181] [microNPU] Fix output mismatch in Leaky ReLU (#11397) * [microNPU] Fix output mismatch in Leaky ReLU All codegen tests have been running with a representative dataset between 0,1 which masked an output mismatch in Leaky ReLU when compared to TFLite kernels. This issue can be replicated by replacing the representative dataset range with something like -1,1. To fix this mismatch, we use the same implementation for calculating LUT values as Vela which uses arithmetic constrained to quantized values, rather than the previously used floating point calculations. Change-Id: I0ed52215acd27722873be609271971b6fc4aaef1 * fix lint Change-Id: Ica7de0c000ee015e79fe10985b2ec7a9b341861f * fix lint again Change-Id: I005d90ad248bfff7090f99d161eefbdc962cba48 --- .../relay/backend/contrib/ethosu/legalize.py | 88 ++++++++++++------- .../contrib/test_ethosu/test_codegen.py | 6 +- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index d83cd403ca144..c940abdeab5f5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -16,10 +16,11 @@ # under the License. 
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" -from typing import List, Type, Callable, Any, Dict +from typing import List, Type, Callable import math import numpy as np # type: ignore +from ethosu.vela import scaling, fp_math import tvm # type: ignore from tvm import relay @@ -132,7 +133,6 @@ def get_lut_from_func( ofm_scale: float, ofm_zp: int, func: Callable[[float], float], - func_params: Dict[str, Any], ) -> List[int]: """Calculates the values of the lookup table based on the calculation function""" @@ -142,7 +142,7 @@ def get_lut_from_func( qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max for x in range(qmin, qmax + 1): x_real = ifm_scale * (x - ifm_zp) - out_real = func(x_real, **func_params) + out_real = func(x_real) lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale)) lut_result = min(qmax, max(qmin, lut_result)) lut_values.append(lut_result) @@ -165,29 +165,10 @@ def __init__( self.activation_type = activation_type self.calc_func = calc_func - def get_calc_func_params(self, expr: tvm.relay.Expr) -> Dict[str, Any]: - """ - Overridable method that can be used to extract additional arguments - for passing to calc_func. - - Parameters - ---------- - expr : tvm.relay.Expr - The matched composite activation function. - - Returns - ------- - Dict[str, Any] - Maps argument name to argument value. - """ - return {} - def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map): params = self.params_class(post.op.body) params.ifm.tensor = post.args[0] - calc_func_params = self.get_calc_func_params(post.op) - input_scale = float(params.ifm.q_params.scale_f32) input_zp = int(params.ifm.q_params.zero_point) output_scale = float(params.ofm.q_params.scale_f32) @@ -199,7 +180,6 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c output_scale, output_zp, self.calc_func, - calc_func_params, ) lut = relay.const(lut_values, dtype=params.ifm.dtype) @@ -257,19 +237,65 @@ def leaky_relu_calc_func(x: float, alpha: float) -> float: return x if x >= 0 else x * alpha -class LeakyReLURewriter(LutActivationRewriter): +class LeakyReLURewriter(DFPatternCallback): """This pass adds leaky relu as a LUT for identity op.""" def __init__(self): - super().__init__( - params_class=ethosu_patterns.LeakyReLUParams, - activation_type="LUT", - calc_func=leaky_relu_calc_func, + super().__init__(require_type=True, rewrite_once=True) + self.params_class = ethosu_patterns.LeakyReLUParams + self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})( + wildcard() ) - def get_calc_func_params(self, expr: tvm.relay.Expr) -> Dict[str, Any]: - params = ethosu_patterns.LeakyReLUParams(expr.body) - return {"alpha": params.alpha} + def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map): + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[0] + + input_scale = np.double(float(params.ifm.q_params.scale_f32)) + input_zp = int(params.ifm.q_params.zero_point) + output_scale = np.double(float(params.ofm.q_params.scale_f32)) + output_zp = int(params.ofm.q_params.zero_point) + + alpha = params.alpha + + # The calculation of the LUT values is similar to that in Vela + # convert_lrelu_to_lut(op, arch) + # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#864) # pylint: 
disable=line-too-long + alpha_scalar = 1 + alpha_scale, alpha_shift = scaling.elementwise_mul_scale(input_scale, alpha, output_scale) + identity_scale, identity_shift = scaling.elementwise_mul_scale(input_scale, 1, output_scale) + + dtype = params.ifm.dtype + qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max + + def calculate_lut_value(i): + zp_shift = ( + fp_math.multiply_by_quantized_multiplier( + alpha_scalar * (i - input_zp), alpha_scale, alpha_shift + ) + if i < input_zp + else fp_math.multiply_by_quantized_multiplier( + i - input_zp, identity_scale, identity_shift + ) + ) + + return min(qmax, max(qmin, output_zp + zp_shift)) + + values = list(map(calculate_lut_value, range(qmin, qmax + 1))) + lut = relay.const(values, dtype=dtype) + + # We baked the requantization into the LUT, so we don't requantize the identity operator + identity = ethosu_ops.ethosu_identity( + ifm=params.ifm.tensor, + lut=lut, + ifm_scale=input_scale, + ifm_zero_point=input_zp, + ofm_scale=input_scale, + ofm_zero_point=input_zp, + activation="LUT", + ) + + return identity class Conv2DRewriter(DFPatternCallback): diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b6b78c3357605..b73ebd5361192 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1022,7 +1022,11 @@ def leaky_relu_func(x): return tf.nn.leaky_relu(x, alpha=alpha) infra.compare_tvm_with_tflite( - leaky_relu_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type) + leaky_relu_func, + [ifm_shape], + accel_type, + enable_cascader=is_u55_accel_type(accel_type), + ranges=[(-1, 1)], ) From 1aac4d6826192383a755369ab5ccfe4876e8902b Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 6 Jun 2022 15:10:22 +0100 Subject: [PATCH 052/181] [microNPU] Optimize separate padding operation for conv2d (#11468) Optimizes a case where padding appears as a separate nn.pad operation followed by a qnn.conv2d. If possible, the nn.pad will be partitioned and offloaded together with the qnn.conv2d operation, as opposed to separately. As a fallback, both operations will be considered separately. 
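For illustration, the targeted graphs look roughly like the new TFLite tests added below (the shapes, strides and padding here are only illustrative):

```python
# Sketch only: explicit padding expressed as a separate pad op, followed by a
# VALID convolution. After this change the pair can be offloaded to the NPU as
# a single convolution with a combined padding attribute.
import numpy as np
import tensorflow as tf

@tf.function
def conv2d(x):
    x = tf.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]], "CONSTANT")  # becomes nn.pad in Relay
    weight = tf.constant(np.random.uniform(size=(3, 2, 3, 3)), dtype=tf.float32)
    return tf.nn.conv2d(x, weight, strides=[1, 1, 1, 1], padding="VALID", dilations=(2, 1))
```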
cc Mousius NicolaLancellotti ekalda manupa-arm --- python/tvm/relay/op/contrib/ethosu.py | 66 +++++- tests/python/contrib/test_ethosu/infra.py | 11 +- .../contrib/test_ethosu/test_codegen.py | 68 +++++- .../contrib/test_ethosu/test_legalize.py | 216 ++++++++++++++++++ 4 files changed, 349 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index dfdc0c82fb1e9..806bf6dce2e89 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -201,6 +201,8 @@ def __init__(self, func_body: tvm.relay.Function): from tvm.relay.backend.contrib.ethosu.util import RequantArgs activation = None + separate_padding = None + if str(func_body.op) in self.activation_map.keys(): activation = func_body requantize_op = activation.args[0] @@ -208,8 +210,11 @@ def __init__(self, func_body: tvm.relay.Function): requantize_op = func_body bias_add = requantize_op.args[0] qnn_conv2d = bias_add.args[0] + if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad": + separate_padding = qnn_conv2d.args[0] data_layout = qnn_conv2d.attrs.data_layout self.kernel_layout = qnn_conv2d.attrs.kernel_layout + # We consider the weights & biases as params as it should be a Constant self.weights = TensorParams( qnn_conv2d.args[QConv2DArgs.WEIGHTS.value], @@ -224,8 +229,11 @@ def __init__(self, func_body: tvm.relay.Function): requantize_op.args[RequantArgs.IFM_SCALE.value], requantize_op.args[RequantArgs.IFM_ZERO_POINT.value], ) + ifm_tensor = ( + separate_padding.args[0] if separate_padding else qnn_conv2d.args[QConv2DArgs.IFM.value] + ) self.ifm = TensorParams( - qnn_conv2d.args[QConv2DArgs.IFM.value], + ifm_tensor, data_layout, qnn_conv2d.args[QConv2DArgs.IFM_SCALE.value], qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value], @@ -237,7 +245,10 @@ def __init__(self, func_body: tvm.relay.Function): requantize_op.args[RequantArgs.OFM_ZERO_POINT.value], ) attrs = qnn_conv2d.attrs - self.padding = attrs.padding + + pad_value = int(qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value].data.asnumpy()) + self.padding = self.extract_padding(attrs.padding, separate_padding, pad_value) + self.strides = attrs.strides self.dilation = attrs.dilation self.activation = activation @@ -250,6 +261,37 @@ def __init__(self, func_body: tvm.relay.Function): if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]: self.is_depthwise = True + @staticmethod + def extract_padding( + operator_padding: Tuple[int, int, int, int], + separate_padding: relay.Call, + pad_value: int, + ) -> Optional[Tuple[int, int, int, int]]: + """ + Convolution operations can sometimes have padding represented as a separate + padding operation before the convolution operation itself. Here we can check + whether these representations can be combined into a single padding attribute + as part of the NPU convolution itself. If the padding specified by the separate + nn.pad operation is not supported, None will be returned. This will cause the + nn.pad to be offloaded separately. 
+ """ + if separate_padding is None: + return operator_padding + if pad_value != int(separate_padding.args[1].data.asnumpy()): + return None + pad_width = separate_padding.attrs["pad_width"] + if len(pad_width) != 4: + return None + if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]: + return None + top, left, bottom, right = operator_padding + return [ + top + pad_width[1][0], + left + pad_width[2][0], + bottom + pad_width[1][1], + right + pad_width[2][1], + ] + def is_valid(self) -> bool: """ This function checks whether QnnConv2D has compatible attributes with the NPU @@ -267,7 +309,7 @@ def is_valid(self) -> bool: return False if not check_dilation(self.dilation): return False - if not check_padding(self.padding, self.padding_bounds): + if not self.padding or not check_padding(self.padding, self.padding_bounds): return False legal_groups = [1, self.ofm.shape[3]] if self.groups not in legal_groups: @@ -437,7 +479,7 @@ def is_valid(self): return False if not check_dilation(self.dilation): return False - if not check_padding(self.padding, self.padding_bounds): + if not self.padding or not check_padding(self.padding, self.padding_bounds): return False if self.weights.layout != "HWOI": return False @@ -453,8 +495,14 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ This function creates the pattern for qnn.conv2D with optional fused RELU activation. """ + optional_pad = is_op("nn.pad")(wildcard(), is_constant()) qnn_conv2d = is_op("qnn.conv2d")( - wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + optional_pad | wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), ).has_attr({"kernel_layout": "HWIO"}) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( @@ -468,8 +516,14 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation. 
""" + optional_pad = is_op("nn.pad")(wildcard(), is_constant()) qnn_conv2d = is_op("qnn.conv2d")( - wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + optional_pad | wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), ).has_attr({"kernel_layout": "HWOI"}) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index a1bdcb47e62d1..1f999781e3b1b 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -473,10 +473,17 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1] assert len(strides) == 2 assert len(dilation) == 2 assert len(kernel_shape) == 2 - if padding.lower() == "valid": + if isinstance(padding, tuple): + h = ( + ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0] + padding[0] + padding[2] + ) // strides[0] + w = ( + ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1] + padding[1] + padding[3] + ) // strides[1] + elif padding.lower() == "valid": h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0]) w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1]) - if padding.lower() == "same": + elif padding.lower() == "same": h = math.ceil(ifm_shape[1] / strides[0]) w = math.ceil(ifm_shape[2] / strides[1]) ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b73ebd5361192..2d3489889e8ab 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -72,13 +72,43 @@ def conv2d(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type) +def test_tflite_conv2d_with_separate_pad(): + np.random.seed(0) + + ifm_shape = (1, 55, 34, 3) + kernel_shape = (3, 2) + strides = (1, 1) + dilation = (2, 1) + padding = (0, 0, 1, 1) + + @tf.function + def conv2d(x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + dilations=dilation, + ) + + infra.compare_tvm_with_tflite(conv2d, [ifm_shape], "ethos-u55-256") + + @pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 2), (1, 27, 42, 3)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @@ -120,7 +150,7 @@ def conv2d_double(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op2 = tf.nn.relu(op2) return op2 @@ -156,7 +186,7 @@ def conv_invalid_scale(x): padding=padding, dilations=dilation, ) - if activation: + if activation == "RELU": op = tf.nn.relu(op) return op @@ -191,13 +221,43 @@ def depthwise_conv2d(x): op = tf.nn.depthwise_conv2d( x, weight, strides=tf_strides, padding=padding, dilations=dilation ) - if activation_function: + if activation_function == "RELU": op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type) +def 
test_tflite_depthwise_conv2d_with_separate_pad(): + np.random.seed(0) + + ifm_shape = (1, 23, 32, 7) + kernel_shape = (1, 2) + strides = (3, 2) + dilation = (1, 1) + padding = (0, 0, 1, 1) + + @tf.function + def depthwise_conv2d(x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.depthwise_conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + dilations=dilation, + ) + + infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256") + + @pytest.mark.parametrize( "accel_type", ACCEL_TYPES, diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 2dd5eff91373b..3f8b5f7d5b583 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -347,6 +347,114 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +def test_tflite_conv2d_with_separate_padding_legalize(): + dtype = "int8" + ifm_shape = (1, 55, 34, 3) + kernel_shape = (3, 2) + strides = (1, 1) + dilation = (2, 1) + padding = (0, 0, 1, 1) + + def create_tflite_graph_single(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + dilations=dilation, + ) + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + op = ext_func.body + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # check OFM + ofm = op.checked_type + expected_ofm_shape = infra.compute_ofm_shape( + ifm_shape, padding, kernel_shape, strides, dilation + ) + assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == kernel_shape[0] + assert weights_ohwi.shape[2] == kernel_shape[1] + assert weights_ohwi.shape[3] == 3 + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + assert list(op.attrs.padding) == list(padding) + assert 
list(op.attrs.strides) == list(strides) + assert list(op.attrs.dilation) == list(dilation) + + conv2d_pattern_table = [ + ( + ethosu.QnnConv2DParams.composite_name, + ethosu.qnn_conv2d_pattern(), + lambda pat: ethosu.QnnConv2DParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph_single() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, conv_params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], conv_params) + mod = partition_ethosu_by_table(mod, conv2d_pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) @pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @@ -458,6 +566,114 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +def test_tflite_depthwise_conv2d_with_separate_padding_legalize(): + dtype = "int8" + ifm_shape = (1, 23, 32, 7) + kernel_shape = (1, 2) + strides = (3, 2) + dilation = (1, 1) + padding = (0, 0, 1, 1) + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.depthwise_conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + dilations=dilation, + ) + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + op = ext_func.body + ofm_channels = op.attrs.ofm_channels + + # check IFM + ifm = op.args[0].checked_type + assert list(ifm.shape) == list(ifm_shape) + assert str(ifm.dtype) == dtype + assert ifm.shape[3] == ofm_channels + + # check OFM + ofm = op.checked_type + expected_ofm_shape = infra.compute_ofm_shape( + ifm_shape, padding, kernel_shape, strides, dilation + ) + assert list(ofm.shape) == list(expected_ofm_shape) + assert str(ofm.dtype) == dtype + assert ofm.shape[3] == ofm_channels + + # check weights + weights_ohwi = op.args[1].data.asnumpy() + assert str(weights_ohwi.dtype) == dtype + assert weights_ohwi.shape[0] == ofm_channels + assert weights_ohwi.shape[1] == kernel_shape[0] + assert weights_ohwi.shape[2] == kernel_shape[1] + assert weights_ohwi.shape[3] == 1 # only depth multiplier 1 is supported + + # Check that scale_bias matches weight tensor + assert list(op.args[2].checked_type.shape)[0] == ofm_channels + + assert list(op.attrs.padding) == list(padding) + assert 
op.attrs.ofm_channels == ofm_channels + assert list(op.attrs.strides) == list(strides) + assert list(op.attrs.dilation) == list(dilation) + + depthwise_pattern_table = [ + ( + ethosu.QnnDepthwiseConv2DParams.composite_name, + ethosu.qnn_depthwise_conv2d_pattern(), + lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, depthwise_pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + + @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) @pytest.mark.parametrize( From 9d6599c928ec4de1aede59927fcc5f651096e358 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 6 Jun 2022 08:49:22 -0700 Subject: [PATCH 053/181] [PROFILER] Add configuration information to profiler (#11530) Configuration is a place to store extra information related to the specific profiler run. Right now it is just the executor used and the number of threads. The roofline analysis also adds peak flops and peak bandwidth. --- include/tvm/runtime/profiling.h | 17 ++- python/tvm/runtime/profiling/__init__.py | 11 +- python/tvm/utils/roofline.py | 5 +- src/node/structural_hash.cc | 1 + .../debug/graph_executor_debug.cc | 2 +- src/runtime/profiling.cc | 111 ++++++++++++------ src/runtime/vm/profiler/vm.cc | 6 +- .../python/unittest/test_runtime_profiling.py | 3 + 8 files changed, 109 insertions(+), 47 deletions(-) diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 0163f0c2e49e1..83c26933be45b 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_PROFILING_H_ #include +#include #include #include #include @@ -192,6 +193,11 @@ class ReportNode : public Object { * because these metrics include the overhead of the executor. */ Map> device_metrics; + /*! Configuration used for this profiling run. Includes number of threads, executor. + * + * Values must be an object type that can be used with device_metrics. + */ + Map configuration; /*! \brief Output `calls` in CSV format. * * Note that this does not include `device_metrics`, it only includes per-call metrics. @@ -255,9 +261,11 @@ class Report : public ObjectRef { /*! Construct a Report from a set of calls (with associated metrics) and per-device metrics. * \param calls Function calls and associated metrics. * \param device_metrics Per-device metrics for overall execution. + * \param configuration Configuration data specific to this profiling run. */ explicit Report(Array> calls, - Map> device_metrics); + Map> device_metrics, + Map configuration); /*! Deserialize a Report from a JSON object. Needed for sending the report over RPC. * \param json Serialized json report from `ReportNode::AsJSON`. @@ -366,8 +374,10 @@ class Profiler { * \param devs The list of devices the profiler will be running on. Should * include all devices used by profiled operators. * \param metric_collectors Additional `MetricCollector`s to use with this profiler. 
+ * \param configuration Additional configuration data to add to the outputted profiling report. */ - explicit Profiler(std::vector devs, std::vector metric_collectors); + explicit Profiler(std::vector devs, std::vector metric_collectors, + std::unordered_map configuration = {}); /*! \brief Start the profiler. * * This function should only be called once per object. @@ -400,7 +410,7 @@ class Profiler { * \returns A `Report` that can either be formatted as CSV (with `.AsCSV`) * or as a human readable table (with `.AsTable`). */ - profiling::Report Report(bool aggregate = true, bool sort = true); + profiling::Report Report(); /*! \brief Check if the profiler is currently running. * \returns Whether or not the profiler is running. */ @@ -412,6 +422,7 @@ class Profiler { std::vector calls_; std::stack in_flight_; std::vector collectors_; + std::unordered_map configuration_; }; /* \brief A duration in time. */ diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 5737790378278..347d8b9f94f15 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -36,7 +36,10 @@ class Report(Object): """ def __init__( - self, calls: Sequence[Dict[str, Object]], device_metrics: Dict[str, Dict[str, Object]] + self, + calls: Sequence[Dict[str, Object]], + device_metrics: Dict[str, Dict[str, Object]], + configuration: Dict[str, Object], ): """Construct a profiling report from a list of metrics and per-device metrics. @@ -47,8 +50,12 @@ def __init__( device_metrics : Dict[str, Dict[str, Object]] Per device metrics. + + configuration : Dict[str, Object] + Configuration of TVM for this profiling run. Includes number of + threads, executor. """ - self.__init_handle_by_constructor__(_ffi_api.Report, calls, device_metrics) + self.__init_handle_by_constructor__(_ffi_api.Report, calls, device_metrics, configuration) def csv(self): """Convert this profiling report into CSV format. 
diff --git a/python/tvm/utils/roofline.py b/python/tvm/utils/roofline.py index 8a17b9f003123..6cfca81c5c420 100644 --- a/python/tvm/utils/roofline.py +++ b/python/tvm/utils/roofline.py @@ -400,7 +400,10 @@ def roofline_from_existing( new_calls.append(call) else: new_calls.append(call) - return profiling.Report(new_calls, report.device_metrics) + new_configuration = dict(report.configuration.items()) + new_configuration["Estimated Peak FMA FLOP/s"] = profiling.Ratio(peak_flops) + new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth) + return profiling.Report(new_calls, report.device_metrics, new_configuration) def roofline_analysis( diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index e97e5f41bfc28..23811e2190784 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -521,6 +521,7 @@ struct ReportNodeTrait { static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) { attrs->Visit("calls", &report->calls); attrs->Visit("device_metrics", &report->device_metrics); + attrs->Visit("configuration", &report->configuration); } static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index bd3b0db0403f3..4a950153954ff 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -294,7 +294,7 @@ class GraphExecutorDebug : public GraphExecutor { */ profiling::Report Profile(Array collectors) { std::vector cs(collectors.begin(), collectors.end()); - profiling::Profiler prof(devices_, cs); + profiling::Profiler prof(devices_, cs, {{String("Executor"), String("Graph")}}); // warm up. 1 iteration does not seem enough. for (int i = 0; i < 3; i++) { diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 9499a6e7a5bbb..9f95bf18f74b2 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -105,8 +105,9 @@ TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); namespace profiling { -Profiler::Profiler(std::vector devs, std::vector metric_collectors) - : devs_(devs), collectors_(metric_collectors) { +Profiler::Profiler(std::vector devs, std::vector metric_collectors, + std::unordered_map configuration) + : devs_(devs), collectors_(metric_collectors), configuration_(configuration) { is_running_ = false; std::vector wrapped_devs; for (auto dev : devs) { @@ -117,6 +118,9 @@ Profiler::Profiler(std::vector devs, std::vector metric } // reset the thread pool so that PAPI eventset hooks are set in all threads. threading::ResetThreadPool(); + + configuration_[String("Number of threads")] = + ObjectRef(make_object(threading::NumThreads())); } void Profiler::Start() { @@ -279,7 +283,7 @@ String ReportNode::AsCSV() const { } namespace { -void print_metric(std::ostream& os, ObjectRef o) { +void metric_as_json(std::ostream& os, ObjectRef o) { if (o.as()) { os << "{\"string\":" << "\"" << Downcast(o) << "\"" @@ -309,13 +313,14 @@ String ReportNode::AsJSON() const { // value we want to print. Instead we construct the json by hand because it // is easier. 
s << "{"; + s << "\"calls\":["; for (size_t i = 0; i < calls.size(); i++) { size_t j = 0; s << "{"; for (const auto& kv : calls[i]) { s << "\"" << kv.first << "\":"; - print_metric(s, kv.second); + metric_as_json(s, kv.second); if (j < calls[i].size() - 1) { s << ","; } @@ -326,7 +331,8 @@ String ReportNode::AsJSON() const { s << ","; } } - s << "],"; + s << "],"; // end calls + s << "\"device_metrics\":{"; size_t i = 0; for (const auto& dev_kv : device_metrics) { @@ -334,7 +340,7 @@ String ReportNode::AsJSON() const { s << "\"" << dev_kv.first << "\":{"; for (const auto& metric_kv : dev_kv.second) { s << "\"" << metric_kv.first << "\":"; - print_metric(s, metric_kv.second); + metric_as_json(s, metric_kv.second); if (j < dev_kv.second.size() - 1) { s << ","; } @@ -346,7 +352,20 @@ String ReportNode::AsJSON() const { } i++; } - s << "}}"; + s << "},"; // end device metrics + + s << "\"configuration\":{"; + size_t k = 0; + for (const auto& kv : configuration) { + s << "\"" << kv.first << "\":"; + metric_as_json(s, kv.second); + if (k < configuration.size() - 1) { + s << ","; + } + k++; + } + s << "}"; // end configuration + s << "}"; return s.str(); } @@ -395,6 +414,35 @@ ObjectRef AggregateMetric(const std::vector& metrics) { } } +static String print_metric(ObjectRef metric) { + std::string val; + if (metric.as()) { + std::stringstream s; + s.imbue(std::locale("")); // for 1000s seperators + s << std::fixed << metric.as()->value; + val = s.str(); + } else if (metric.as()) { + std::stringstream s; + s.imbue(std::locale("")); // for 1000s seperators + s << std::fixed << std::setprecision(2) << metric.as()->microseconds; + val = s.str(); + } else if (metric.as()) { + std::stringstream s; + s << std::fixed << std::setprecision(2) << metric.as()->percent; + val = s.str(); + } else if (metric.as()) { + std::stringstream s; + s.imbue(std::locale("")); // for 1000s seperators + s << std::setprecision(2) << metric.as()->ratio; + val = s.str(); + } else if (metric.as()) { + val = Downcast(metric); + } else { + LOG(FATAL) << "Cannot print metric of type " << metric->GetTypeKey(); + } + return val; +} + String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; @@ -533,30 +581,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // fill empty data with empty strings cols[i].push_back(""); } else { - std::string val; - if ((*it).second.as()) { - std::stringstream s; - s.imbue(std::locale("")); // for 1000s seperators - s << std::fixed << (*it).second.as()->value; - val = s.str(); - } else if ((*it).second.as()) { - std::stringstream s; - s.imbue(std::locale("")); // for 1000s seperators - s << std::fixed << std::setprecision(2) << (*it).second.as()->microseconds; - val = s.str(); - } else if ((*it).second.as()) { - std::stringstream s; - s << std::fixed << std::setprecision(2) << (*it).second.as()->percent; - val = s.str(); - } else if ((*it).second.as()) { - std::stringstream s; - s.imbue(std::locale("")); // for 1000s seperators - s << std::setprecision(2) << (*it).second.as()->ratio; - val = s.str(); - } else if ((*it).second.as()) { - val = Downcast((*it).second); - } - cols[i].push_back(val); + cols[i].push_back(print_metric((*it).second)); } } } @@ -592,6 +617,12 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } s << std::endl; } + + // Add configuration information. 
It will not be aligned with the columns. + s << std::endl << "Configuration" << std::endl << "-------------" << std::endl; + for (auto kv : configuration) { + s << kv.first << ": " << print_metric(kv.second) << std::endl; + } return s.str(); } @@ -599,7 +630,7 @@ std::string DeviceString(Device dev) { return DeviceName(dev.device_type) + std::to_string(dev.device_id); } -Report Profiler::Report(bool aggregate, bool sort) { +Report Profiler::Report() { // sync all timers and normalize rows std::vector> rows; for (auto& cf : calls_) { @@ -638,14 +669,16 @@ Report Profiler::Report(bool aggregate, bool sort) { converted_rows.push_back(row); } - return profiling::Report(converted_rows, device_metrics); + return profiling::Report(converted_rows, device_metrics, configuration_); } Report::Report(Array> calls, - Map> device_metrics) { + Map> device_metrics, + Map configuration) { auto node = make_object(); node->calls = std::move(calls); node->device_metrics = std::move(device_metrics); + node->configuration = std::move(configuration); data_ = std::move(node); } @@ -697,6 +730,7 @@ Report Report::FromJSON(String json) { std::string key; Array> calls; Map> device_metrics; + Map configuration; reader.BeginObject(); while (reader.NextObjectItem(&key)) { @@ -713,10 +747,12 @@ Report Report::FromJSON(String json) { device_metrics.Set(device_name, parse_metrics(&reader)); } // reader.EndObject(); + } else if (key == "configuration") { + configuration = parse_metrics(&reader); } } - return Report(calls, device_metrics); + return Report(calls, device_metrics, configuration); } TVM_REGISTER_OBJECT_TYPE(DurationNode); @@ -855,8 +891,9 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, TVM_REGISTER_GLOBAL("runtime.profiling.Report") .set_body_typed([](Array> calls, - Map> device_metrics) { - return Report(calls, device_metrics); + Map> device_metrics, + Map configuration) { + return Report(calls, device_metrics, configuration); }); TVM_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 393d1b399878f..0ace910b5c539 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -58,9 +58,9 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, // on remotes, we accept a nullptr for collectors. if (collectors.defined()) { std::vector cs(collectors.begin(), collectors.end()); - prof_ = profiling::Profiler(devices, cs); + prof_ = profiling::Profiler(devices, cs, {{String("Executor"), String("VM")}}); } else { - prof_ = profiling::Profiler(devices, {}); + prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); } auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); @@ -77,7 +77,7 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, return report; }); } else if (name == "profile_rpc") { - // We cannot return a Report over RPC because TMV RPC mechanism only + // We cannot return a Report over RPC because TVM RPC mechanism only // supports a subset of Object classes. Instead we serialize it on the // remote (here) and deserialize it on the other end. 
return TypedPackedFunc([sptr_to_self, this](std::string arg_name) { diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 29a8414337756..ab22bd2b9c481 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -69,6 +69,7 @@ def test_vm(target, dev): assert "Total" in str(report) assert "AllocTensorReg" in str(report) assert "AllocStorage" in str(report) + assert report.configuration["Executor"] == "VM" csv = read_csv(report) assert "Hash" in csv.keys() @@ -102,6 +103,7 @@ def test_graph_executor(target, dev): assert "fused_nn_softmax" in str(report) assert "Total" in str(report) assert "Hash" in str(report) + assert "Graph" in str(report) @tvm.testing.parametrize_targets("cuda", "llvm") @@ -147,6 +149,7 @@ def test_json(): parsed = json.loads(report.json()) assert "device_metrics" in parsed assert "calls" in parsed + assert "configuration" in parsed assert "Duration (us)" in parsed["calls"][0] assert "microseconds" in parsed["calls"][0]["Duration (us)"] assert len(parsed["calls"]) > 0 From 68dcecc926f890429a8f2cba9ce55eab6a18fa6e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Jun 2022 20:02:18 -0700 Subject: [PATCH 054/181] [MetaSchedule] Evo Independence from TaskScheduler (#11590) Per discussion with @Kathryn-cat, we realized that the current API design could be verbose if we only want to tune a single task, in which case a dummy task scheduler still needs to be established to supply `EvolutionarySearch` with proper `CostModel` and `Database`. This PR fixes this UX issue. --- include/tvm/meta_schedule/search_strategy.h | 17 +- include/tvm/meta_schedule/task_scheduler.h | 20 +-- include/tvm/meta_schedule/tune_context.h | 2 - .../search_strategy/search_strategy.py | 24 ++- .../task_scheduler/gradient_based.py | 10 +- .../task_scheduler/round_robin.py | 10 +- .../task_scheduler/task_scheduler.py | 10 +- .../measure_callback/add_to_database.cc | 5 +- .../search_strategy/evolutionary_search.cc | 148 +++++++++--------- .../search_strategy/replay_func.cc | 48 +++--- .../search_strategy/replay_trace.cc | 63 ++++---- .../search_strategy/search_strategy.cc | 7 + .../task_scheduler/gradient_based.cc | 7 +- .../task_scheduler/round_robin.cc | 7 +- .../task_scheduler/task_scheduler.cc | 6 +- .../test_meta_schedule_measure_callback.py | 22 ++- .../test_meta_schedule_search_strategy.py | 93 +++++------ .../test_meta_schedule_task_scheduler.py | 60 +++---- 18 files changed, 298 insertions(+), 261 deletions(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 6895673a04cc3..139de7c99d042 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -113,12 +113,16 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Pre-tuning for the search strategy. - * \param design_spaces The design spaces for pre-tuning. + * \param design_spaces The design spaces used during tuning process. + * \param database The database used during tuning process. + * \param cost_model The cost model used during tuning process. * \note Pre-tuning is supposed to be called before the tuning process and after the * initialization. Because the search strategy is stateful, we can always call pretuning * and reset the search strategy. 
*/ - virtual void PreTuning(const Array& design_spaces) = 0; + virtual void PreTuning(const Array& design_spaces, + const Optional& database, + const Optional& cost_model) = 0; /*! * \brief Post-tuning for the search strategy. @@ -159,7 +163,8 @@ class PySearchStrategyNode : public SearchStrategyNode { * \brief The function type of `PreTuning` method. * \param design_spaces The design spaces for pre-tuning. */ - using FPreTuning = runtime::TypedPackedFunc&)>; + using FPreTuning = runtime::TypedPackedFunc&, const Optional&, const Optional&)>; /*! \brief The function type of `PostTuning` method. */ using FPostTuning = runtime::TypedPackedFunc; /*! @@ -199,10 +204,8 @@ class PySearchStrategyNode : public SearchStrategyNode { this->f_initialize_with_tune_context(context); } - void PreTuning(const Array& design_spaces) final { - ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; - this->f_pre_tuning(design_spaces); - } + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final; void PostTuning() final { ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 7453c2b484b90..5953a2c3e42b1 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -74,13 +74,13 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief The runner of the scheduler. */ Runner runner{nullptr}; /*! \brief The database of the scheduler. */ - Database database{nullptr}; - /*! \brief The maximum number of trials allowed. */ - int max_trials; + Optional database; /*! \brief The cost model of the scheduler. */ Optional cost_model; /*! \brief The list of measure callbacks of the scheduler. */ Array measure_callbacks; + /*! \brief The maximum number of trials allowed. */ + int max_trials; /*! \brief The number of trials already conducted. */ int num_trials_already; /*! \brief The tuning task's logging function. t*/ @@ -94,9 +94,9 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); - v->Visit("max_trials", &max_trials); v->Visit("cost_model", &cost_model); v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("max_trials", &max_trials); v->Visit("num_trials_already", &num_trials_already); // `logging_func` is not visited } @@ -243,10 +243,10 @@ class TaskScheduler : public runtime::ObjectRef { TVM_DLL static TaskScheduler RoundRobin(Array tasks, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func); /*! * \brief Create a task scheduler that fetches tasks in a gradient based fashion. 
@@ -268,10 +268,10 @@ class TaskScheduler : public runtime::ObjectRef { Array task_weights, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func, // double alpha, // int window_size, // @@ -297,10 +297,10 @@ class TaskScheduler : public runtime::ObjectRef { Array tasks, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index faa24fc99f4ce..d63fb819f3639 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -61,8 +61,6 @@ class TuneContextNode : public runtime::Object { /*! \brief The number of threads to be used. */ int num_threads; - /*! \brief The task scheduler that owns the tune context */ - const TaskSchedulerNode* task_scheduler; /*! \brief Whether the tuning task has been stopped or finished. */ bool is_terminated; /*! \brief The measure candidates. */ diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 07c47f01d1c55..14b46a0785f1d 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -18,7 +18,7 @@ Meta Schedule search strategy that generates the measure candidates for measurement. """ -from typing import Callable, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, List, Optional from tvm._ffi import register_object from tvm.runtime import Object @@ -29,6 +29,8 @@ from ..runner import RunnerResult if TYPE_CHECKING: + from ..cost_model import CostModel + from ..database import Database from ..tune_context import TuneContext @@ -87,15 +89,29 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None: self, context ) - def pre_tuning(self, design_spaces: List[Schedule]) -> None: + def pre_tuning( + self, + design_spaces: List[Schedule], + database: Optional["Database"] = None, + cost_model: Optional["CostModel"] = None, + ) -> None: """Pre-tuning for the search strategy. Parameters ---------- design_spaces : List[Schedule] - The design spaces for pre-tuning. + The design spaces used during tuning process. + database : Optional[Database] = None + The database used during tuning process. + cost_model : Optional[CostModel] = None + The cost model used during tuning process. 
""" - _ffi_api.SearchStrategyPreTuning(self, design_spaces) # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyPreTuning( # type: ignore # pylint: disable=no-member + self, + design_spaces, + database, + cost_model, + ) def post_tuning(self) -> None: """Post-tuning for the search strategy.""" diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 6234449bf09b9..20d32dd1c59f9 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -45,11 +45,11 @@ def __init__( task_weights: List[float], builder: Builder, runner: Runner, - database: Database, - max_trials: int, *, + database: Database, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, + max_trials: int, alpha: float = 0.2, window_size: int = 3, seed: int = -1, @@ -68,12 +68,12 @@ def __init__( The runner. database : Database The database. - max_trials : int - The maximum number of trials to run. cost_model : CostModel, default None. The cost model of the scheduler. measure_callbacks : Optional[List[MeasureCallback]] = None The list of measure callbacks of the scheduler. + max_trials : int + The maximum number of trials to run. alpha : float = 0.2 The parameter alpha in gradient computation. window_size : int = 3 @@ -88,9 +88,9 @@ def __init__( builder, runner, database, - max_trials, cost_model, measure_callbacks, + max_trials, make_logging_func(logger), alpha, window_size, diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index a461358283949..ed395643bbaae 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -60,11 +60,11 @@ def __init__( task_weights: List[float], builder: Builder, runner: Runner, - database: Database, - max_trials: int, *, + database: Database, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, + max_trials: int, ) -> None: """Constructor. @@ -80,12 +80,12 @@ def __init__( The runner. database : Database The database. - max_trials : int - The maximum number of trials. cost_model : Optional[CostModel] The cost model. measure_callbacks: Optional[List[MeasureCallback]] The list of measure callbacks of the scheduler. + max_trials : int + The maximum number of trials. 
""" del task_weights self.__init_handle_by_constructor__( @@ -94,8 +94,8 @@ def __init__( builder, runner, database, - max_trials, cost_model, measure_callbacks, + max_trials, make_logging_func(logger), ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 4454078a6f16d..3d57a6b01b9db 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -31,7 +31,6 @@ from ..tune_context import TuneContext from ..utils import make_logging_func - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -177,9 +176,9 @@ class PyTaskScheduler: "builder", "runner", "database", - "max_trials", "cost_model", "measure_callbacks", + "max_trials", ], "methods": [ "tune", @@ -195,18 +194,19 @@ def __init__( tasks: List[TuneContext], builder: Builder, runner: Runner, - database: Database, - max_trials: int, + *, + database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, + max_trials: int, ): self.tasks = tasks self.builder = builder self.runner = runner self.database = database - self.max_trials = max_trials self.cost_model = cost_model self.measure_callbacks = measure_callbacks + self.max_trials = max_trials def tune(self) -> None: """Auto-tuning.""" diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 20581f4630a63..0988da0414e2a 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -27,8 +27,11 @@ class AddToDatabaseNode : public MeasureCallbackNode { const Array& measure_candidates, const Array& builder_results, const Array& runner_results) final { + if (!task_scheduler->database.defined()) { + return; + } TuneContext task = task_scheduler->tasks[task_id]; - Database database = task_scheduler->database; + Database database = task_scheduler->database.value(); Workload workload = database->CommitWorkload(task->mod.value()); Target target = task->target.value(); ICHECK_EQ(runner_results.size(), measure_candidates.size()); diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index bdef26ef876e5..8b36a95217046 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -246,13 +246,41 @@ class EvolutionarySearchNode : public SearchStrategyNode { int ed; /*! \brief The counter of returning empty results. */ int num_empty_iters; - - explicit State(EvolutionarySearchNode* self, Array design_spaces) + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief Pre thread data including module to be tuned and random state. */ + std::vector per_thread_data_; + /*! + * \brief The workloads that are already measured. + * TODO(junrushao1994): add records from the database to avoid re-measuring. + * */ + IRModuleSet measured_workloads_; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The token registered for the given workload in database. 
*/ + Workload token_{nullptr}; + + explicit State(EvolutionarySearchNode* self, Array design_spaces, Database database, + CostModel cost_model) : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter), - num_empty_iters(0) {} + num_empty_iters(0) { + const TuneContextNode* ctx = self->context_; + IRModule mod = ctx->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod)); + this->per_thread_data_.resize(ctx->num_threads); + for (PerThreadData& data : this->per_thread_data_) { + data.mod = DeepCopyIRModule(mod); + data.rand_state = ForkSeed(&self->rand_state_); + } + this->database_ = database; + this->cost_model_ = cost_model; + this->token_ = database->CommitWorkload(mod); + } /*! * \brief Pick up best candidates from database. @@ -293,33 +321,10 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The tuning context of the evolutionary search strategy. */ const TuneContextNode* context_{nullptr}; - /*! \brief The target for the workload. */ - Target target_{nullptr}; - /*! \brief The metadata of the function arguments. */ - Array args_info_{nullptr}; - /*! \brief A Database for selecting useful candidates. */ - Database database_{nullptr}; - /*! \brief A cost model helping to explore the search space */ - CostModel cost_model_{nullptr}; - /*! \brief The postprocessors. */ - Array postprocs_{nullptr}; - /*! \brief Mutators and their probability mass */ - Map mutator_probs_{nullptr}; - /*! \brief The number of threads to use. To be initialized with TuneContext. */ - int num_threads_; /*! \brief The random state. To be initialized with TuneContext. */ TRandState rand_state_; - /*! \brief Pre thread data including module to be tuned and random state. */ - std::vector per_thread_data_; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - /*! \brief The token registered for the given workload in database. */ - Workload token_{nullptr}; - /*! - * \brief The workloads that are already measured. - * TODO(junrushao1994): add records from the database to avoid re-measuring. - * */ - IRModuleSet measured_workloads_; /*** Configuration: global ***/ /*! \brief The number of trials per iteration. 
*/ @@ -351,15 +356,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { // `context_` is not visited - // `target_` is not visited - // `args_info_` is not visited - // `database` is not visited - // `cost_model` is not visited - // `postprocs` is not visited - // `mutator_probs_` is not visited - // `num_threads` is not visited // `rand_state_` is not visited - // `per_thread_data_` is not visited // `state_` is not visited /*** Configuration: global ***/ @@ -386,39 +383,41 @@ class EvolutionarySearchNode : public SearchStrategyNode { CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; CHECK(context->target.defined()) << "Target must be defined!"; this->context_ = context.get(); - this->target_ = context->target.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); - this->mutator_probs_ = context->mutator_probs; - this->postprocs_ = context->postprocs; - this->num_threads_ = context->num_threads; this->rand_state_ = ForkSeed(&context->rand_state); - CHECK(context->task_scheduler != nullptr) - << "ValueError: TaskScheduler is not defined in TuneContext"; - this->cost_model_ = context->task_scheduler->cost_model.value(); - this->database_ = context->task_scheduler->database; - this->token_ = this->database_->CommitWorkload(context->mod.value()); - this->per_thread_data_.resize(this->num_threads_); - for (const auto& kv : this->mutator_probs_) { + for (const auto& kv : context->mutator_probs) { double mass = kv.second->value; TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); } - for (PerThreadData& data : this->per_thread_data_) { - data.mod = DeepCopyIRModule(context->mod.value()); - data.rand_state = ForkSeed(&this->rand_state_); - } this->state_.reset(); } - void PreTuning(const Array& design_spaces) final { + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final { ICHECK(!design_spaces.empty()); + CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; + CHECK(database.defined()) + << "ValueError: Database is not supplied in PreTuning. Evolutionary" + "search algorithm requires a database to be present, so that it " + "could sample from previously-explored population. If you do not " + "intent to store data on disk, please use `tvm.meta_schedule.testing.DummyDatabase`"; + CHECK(cost_model.defined()) + << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search " + "algorithm expects a cost model to filter out potentially less efficient kernels. 
If " + "you do not expect a cost model to help, please use " + "`tvm.meta_schedule.cost_model.RandomModel`"; + if (this->state_ != nullptr) { + TVM_PY_LOG(WARNING, this->context_->logging_func) + << "EvolutionarySearch is already initialized."; + this->state_.reset(); + } ICHECK(this->state_ == nullptr); - // Change to traces Array design_space_traces; design_space_traces.reserve(design_spaces.size()); for (const Schedule& space : design_spaces) { design_space_traces.push_back(space->trace().value()->Simplified(true)); } - this->state_ = std::make_unique(this, design_space_traces); + this->state_ = + std::make_unique(this, design_space_traces, database.value(), cost_model.value()); } void PostTuning() final { @@ -442,16 +441,16 @@ class EvolutionarySearchNode : public SearchStrategyNode { std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { std::vector measured_traces; measured_traces.reserve(num); - Array top_records = self->database_->GetTopK(self->token_, num); + Array top_records = this->database_->GetTopK(this->token_, num); for (TuningRecord record : top_records) { measured_traces.push_back(record->trace); } int actual_num = measured_traces.size(); - ThreadedTraceApply pp(self->postprocs_); + ThreadedTraceApply pp(self->context_->postprocs); std::vector results(actual_num, Schedule{nullptr}); auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, int trace_id) -> void { - PerThreadData& data = self->per_thread_data_.at(thread_id); + PerThreadData& data = this->per_thread_data_.at(thread_id); TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; tir::Trace trace = measured_traces.at(trace_id); @@ -464,17 +463,17 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu throw; } }; - support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); + support::parallel_for_dynamic(0, actual_num, self->context_->num_threads, f_proc_measured); return results; } std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { - ThreadedTraceApply pp(self->postprocs_); + ThreadedTraceApply pp(self->context_->postprocs); std::vector out_schs; while (static_cast(out_schs.size()) < self->init_min_unmeasured) { std::vector results(num, Schedule{nullptr}); auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { - PerThreadData& data = self->per_thread_data_.at(thread_id); + PerThreadData& data = this->per_thread_data_.at(thread_id); TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; Schedule& result = results.at(trace_id); @@ -485,7 +484,7 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu result = sch.value(); } }; - support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + support::parallel_for_dynamic(0, num, self->context_->num_threads, f_proc_unmeasured); for (int i = 0; i < num; i++) { if (results[i].defined()) { out_schs.push_back(results[i]); @@ -501,14 +500,14 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( std::vector population, int num) { ICHECK_GT(num, 0); // The heap to record best schedule, we do not consider schedules that are already measured - IRModuleSet exists = self->measured_workloads_; + IRModuleSet exists = this->measured_workloads_; SizedHeap heap(num); for (int iter = 0;; ++iter) { // Predict normalized score with the cost model, std::vector scores = PredictNormalizedScore(population, // GetRef(self->context_), // - self->cost_model_, // - 
self->args_info_); + this->cost_model_, // + this->args_info_); ICHECK_EQ(scores.size(), population.size()); for (int i = 0, n = population.size(); i < n; ++i) { Schedule sch = population.at(i); @@ -524,18 +523,18 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( if (iter == self->genetic_num_iters) { break; } - // Set threaded samplers, with probability from predicated normalized throughputs - for (PerThreadData& data : self->per_thread_data_) { - data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); + // Set threaded samplers, with probability from predicated normalized throughput + for (PerThreadData& data : this->per_thread_data_) { + data.Set(scores, self->genetic_mutate_prob, self->context_->mutator_probs); } - ThreadedTraceApply pp(self->postprocs_); + ThreadedTraceApply pp(self->context_->postprocs); ConcurrentBitmask cbmask(self->population_size); std::vector next_population(self->population_size, Schedule{nullptr}); // The worker function auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, int trace_id) { // Prepare samplers - PerThreadData& data = self->per_thread_data_.at(thread_id); + PerThreadData& data = this->per_thread_data_.at(thread_id); TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; std::function& trace_sampler = data.trace_sampler; @@ -567,7 +566,8 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( result = population.at(sampled_trace_id); } }; - support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); + support::parallel_for_dynamic(0, self->population_size, self->context_->num_threads, + f_find_candidate); population.swap(next_population); TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter << " done. Summary:\n" << pp.SummarizeFailures(); @@ -607,7 +607,7 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); std::vector results; results.reserve(num); - IRModuleSet& measured_workloads = self->measured_workloads_; + IRModuleSet& measured_workloads = this->measured_workloads_; for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { bool has_best = i_bests < static_cast(bests.size()); bool has_rand = i_rands < static_cast(rands.size()); @@ -677,7 +677,7 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure return NullOpt; } } - return AssembleCandidates(picks, self->args_info_); + return AssembleCandidates(picks, this->args_info_); } void EvolutionarySearchNode::State::NotifyRunnerResults( @@ -713,6 +713,12 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, / return SearchStrategy(n); } +class EvolutionarySearch : public SearchStrategy { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy, + EvolutionarySearchNode); +}; + TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") .set_body_typed(SearchStrategy::EvolutionarySearch); diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 878c872a65fe2..1aaaaa09e8ab8 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -32,8 +32,14 @@ class ReplayFuncNode : public SearchStrategyNode { int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; + /*! 
\brief The metadata of the function arguments. */ + Array args_info_{nullptr}; - explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) { + const TuneContextNode* ctx = self->context_; + ICHECK(ctx); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(ctx->mod.value())); + } inline Optional> GenerateMeasureCandidates(); inline void NotifyRunnerResults(const Array& results); @@ -44,14 +50,8 @@ class ReplayFuncNode : public SearchStrategyNode { /*! \brief The number of total trials. */ int max_trials_per_task; - /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; - /*! \brief The metadata of the function arguments. */ - Array args_info_{nullptr}; - /*! \brief The post processors */ - Array postprocs_{nullptr}; - /*! \brief The space generator for measure candidates generation. */ - SpaceGenerator space_generator_{nullptr}; + /*! \brief The tuning context of the search strategy. */ + const TuneContextNode* context_{nullptr}; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The state of the search strategy. */ @@ -60,10 +60,7 @@ class ReplayFuncNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("max_trials_per_task", &max_trials_per_task); - // `space_generator_` is not visited - // `mod_` is not visited - // `args_info_` is not visited - // `num_threads_` is not visited + // `context_` is not visited. // `rand_state_` is not visited // `state_` is not visited } @@ -72,15 +69,21 @@ class ReplayFuncNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& context) final { - this->space_generator_ = context->space_generator.value(); - this->mod_ = context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); - this->postprocs_ = context->postprocs; + CHECK(context->space_generator.defined()) + << "ValueError: TuneContext.space_generator is not defined"; + CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined"; + this->context_ = context.get(); this->rand_state_ = ForkSeed(&context->rand_state); this->state_.reset(); } - void PreTuning(const Array& design_spaces) final { + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final { + CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; + if (this->state_ != nullptr) { + TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayFunc is already initialized."; + this->state_.reset(); + } ICHECK(this->state_ == nullptr); this->state_ = std::make_unique(this); } @@ -109,21 +112,24 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC } ed = std::min(ed, self->max_trials_per_task); Array result; + const TuneContextNode* ctx = self->context_; + ICHECK(ctx); + IRModule mod = ctx->mod.value(); for (int i = st; i < ed; i++) { for (;;) { - Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + Array schs = ctx->space_generator.value()->GenerateDesignSpace(mod); int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); tir::Schedule sch = schs[design_space_index]; sch->EnterPostproc(); bool failed = false; - for (const Postproc& proc : self->postprocs_) { + for (const Postproc& proc : 
ctx->postprocs) { if (!proc->Apply(sch)) { failed = true; break; } } if (!failed) { - result.push_back(MeasureCandidate(sch, self->args_info_)); + result.push_back(MeasureCandidate(sch, this->args_info_)); break; } } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index f17c5d6c4eb3e..13f32a744e3a0 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -35,8 +35,22 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; + /*! \brief The module to be tuned. */ + Array per_thread_mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + explicit State(ReplayTraceNode* self, Array design_spaces) - : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) { + const TuneContextNode* ctx = self->context_; + ICHECK(ctx); + IRModule mod = ctx->mod.value(); + this->per_thread_mod_.reserve(ctx->num_threads); + for (int i = 0; i < ctx->num_threads; i++) { + this->per_thread_mod_.push_back(DeepCopyIRModule(mod)); + } + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(mod)); + } inline Optional> GenerateMeasureCandidates(); inline void NotifyRunnerResults(const Array& results); @@ -47,14 +61,8 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The number of total trials. */ int max_trials_per_task; - /*! \brief The module to be tuned. */ - Array per_thread_mod_{nullptr}; - /*! \brief The metadata of the function arguments. */ - Array args_info_{nullptr}; - /*! \brief The post processors */ - Array postprocs_{nullptr}; - /*! \brief The number of threads to use. -1 means using logical cpu number. */ - int num_threads_ = -1; + /*! \brief The tuning context of the search strategy. */ + const TuneContextNode* context_{nullptr}; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The state of the search strategy. */ @@ -63,10 +71,7 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("max_trials_per_task", &max_trials_per_task); - // `per_thread_mod_` is not visited - // `args_info_` is not visited - // `postprocs_` is not visited - // `num_threads_` is not visited + // `context_` is not visited. 
// `rand_state_` is not visited // `state_` is not visited } @@ -75,22 +80,20 @@ class ReplayTraceNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; - this->num_threads_ = context->num_threads; - - this->per_thread_mod_.reserve(this->num_threads_); - for (int i = 0; i < this->num_threads_; i++) { - this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value())); - } - - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); - this->postprocs_ = context->postprocs; + CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined"; + this->context_ = context.get(); this->rand_state_ = ForkSeed(&context->rand_state); this->state_.reset(); } - void PreTuning(const Array& design_spaces) final { + void PreTuning(const Array& design_spaces, const Optional& database, + const Optional& cost_model) final { ICHECK(!design_spaces.empty()); + CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; + if (this->state_ != nullptr) { + TVM_PY_LOG(WARNING, this->context_->logging_func) << "RelayTrace is already initialized."; + this->state_.reset(); + } ICHECK(this->state_ == nullptr); Array design_space_traces; design_space_traces.reserve(design_spaces.size()); @@ -124,24 +127,26 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure } ed = std::min(ed, self->max_trials_per_task); ICHECK_LT(st, ed); - std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + const TuneContextNode* ctx = self->context_; + ICHECK(ctx); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, ctx->num_threads); Array per_task_result(ed - st, MeasureCandidate{nullptr}); - ThreadedTraceApply pp(self->postprocs_); + ThreadedTraceApply pp(ctx->postprocs); auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; - IRModule mod = self->per_thread_mod_[thread_id]; + IRModule mod = this->per_thread_mod_[thread_id]; for (;;) { int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; tir::Trace new_trace = tir::Trace(trace->insts, {}); if (Optional sch = pp.Apply(mod, new_trace, &rand_state)) { - per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_)); + per_task_result.Set(task_id, MeasureCandidate(sch.value(), this->args_info_)); break; } } }; - support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); + support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker); return per_task_result; } diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index fefe8dfce76e9..a6a1100cebe60 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -28,6 +28,13 @@ MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) data_ = std::move(n); } +void PySearchStrategyNode::PreTuning(const Array& design_spaces, + const Optional& database, + const Optional& cost_model) { + ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; + this->f_pre_tuning(design_spaces, database, cost_model); +} + SearchStrategy 
SearchStrategy::PySearchStrategy( PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PySearchStrategyNode::FPreTuning f_pre_tuning, // diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index f8cc9d5514941..73d191f593fec 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -189,10 +189,10 @@ TaskScheduler TaskScheduler::GradientBased(Array tasks, Array task_weights, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func, // double alpha, // int window_size, // @@ -227,9 +227,6 @@ TaskScheduler TaskScheduler::GradientBased(Array tasks, n->best_time_cost_per_task_ = std::vector(n_tasks, 1e100); n->num_rounds_already_ = 0; support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); - for (const TuneContext& task : tasks) { - task->task_scheduler = n.get(); - } return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 446b11837930b..ea22878840aff 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -58,10 +58,10 @@ class RoundRobinNode final : public TaskSchedulerNode { TaskScheduler TaskScheduler::RoundRobin(Array tasks, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func) { ObjectPtr n = make_object(); n->tasks = tasks; @@ -74,9 +74,6 @@ TaskScheduler TaskScheduler::RoundRobin(Array tasks, n->logging_func = logging_func; n->num_trials_already = 0; n->task_id = -1; - for (const TuneContext& task : tasks) { - task->task_scheduler = n.get(); - } return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index fd1d95cd1f19b..25867fb4f3bbf 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -117,7 +117,7 @@ void TaskSchedulerNode::InitializeTask(int task_id) { << tir::AsTVMScript(sch->mod()) << "\n" << Concat(trace->AsPython(false), "\n"); } - task->search_strategy.value()->PreTuning(design_spaces); + task->search_strategy.value()->PreTuning(design_spaces, database, cost_model); } void TaskSchedulerNode::Tune() { @@ -203,10 +203,10 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Array tasks, // Builder builder, // Runner runner, // - Database database, // - int max_trials, // + Optional database, // Optional cost_model, // Optional> measure_callbacks, // + int max_trials, // PackedFunc logging_func, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index a1b188930f86a..298b51e0158e5 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -16,12 +16,10 @@ # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import re -from random import random from typing import List import pytest import tvm -from tvm.ir import IRModule, assert_structural_equal from tvm.meta_schedule.builder import BuilderResult from tvm.meta_schedule.measure_callback import PyMeasureCallback from tvm.meta_schedule.runner import RunnerResult @@ -66,7 +64,7 @@ def apply( results: List[RunnerResult], ) -> None: assert len(measure_candidates) == 1 - assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul) assert ( len(builds) == 1 and builds[0].error_msg is None @@ -78,7 +76,14 @@ def apply( measure_callback = FancyMeasureCallback() measure_callback.apply( - RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin( + tasks=[], + task_weights=[], + builder=DummyBuilder(), + runner=DummyRunner(), + database=DummyDatabase(), + max_trials=1, + ), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], @@ -102,7 +107,14 @@ def apply( measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin( + tasks=[], + task_weights=[], + builder=DummyBuilder(), + runner=DummyRunner(), + database=DummyDatabase(), + max_trials=1, + ), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 94042dd753e0d..4eb8aac5a3314 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -123,43 +123,37 @@ def _schedule_matmul_small(sch: Schedule): num_trials_per_iter = 10 max_trials_per_task = 2000 + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - strategy = EvolutionarySearch( - num_trials_per_iter=num_trials_per_iter, - max_trials_per_task=max_trials_per_task, - population_size=5, - init_measured_ratio=0.1, - init_min_unmeasured=50, - genetic_num_iters=3, - genetic_mutate_prob=0.5, - genetic_max_fail_count=10, - eps_greedy=0.9, - ) context = TuneContext( mod=Matmul, - space_generator=ScheduleFn(sch_fn=_schedule_matmul_small), + space_generator=ScheduleFn( + sch_fn=_schedule_matmul_small, + ), + search_strategy=EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + max_trials_per_task=max_trials_per_task, + population_size=5, + init_measured_ratio=0.1, + init_min_unmeasured=50, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ), mutator_probs={ DummyMutator(): 1.0, }, target=tvm.target.Target("llvm"), num_threads=1, # because we are using a mutator from the python side ) - _scheduler = RoundRobin( - tasks=[context], - task_weights=[1.0], - builder=ms.builder.LocalBuilder(), - runner=ms.runner.LocalRunner(), + context.initialize() + strategy = context.search_strategy + strategy.pre_tuning( + context.space_generator.generate_design_space(context.mod), database=DummyDatabase(), cost_model=ms.cost_model.RandomModel(), - measure_callbacks=[], - max_trials=1, ) - context.space_generator.initialize_with_tune_context(context) - spaces = context.space_generator.generate_design_space(context.mod) - - 
strategy.initialize_with_tune_context(context) - strategy.pre_tuning(spaces) - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] candidates = strategy.generate_measure_candidates() while candidates is not None: @@ -177,52 +171,46 @@ def _schedule_matmul_small(sch: Schedule): strategy.post_tuning() assert sum(num_trials_each_iter) == 25 assert num_trials_each_iter.count(0) < 5 - del _scheduler def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable = invalid-name] def _schedule_matmul_empty(sch: Schedule): return sch + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_per_iter = 10 max_trials_per_task = 100 - strategy = EvolutionarySearch( - num_trials_per_iter=num_trials_per_iter, - max_trials_per_task=max_trials_per_task, - population_size=5, - init_measured_ratio=0.1, - init_min_unmeasured=50, - genetic_num_iters=3, - genetic_mutate_prob=0.5, - genetic_max_fail_count=10, - eps_greedy=0.9, - ) context = TuneContext( mod=Matmul, - space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty), + search_strategy=EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + max_trials_per_task=max_trials_per_task, + population_size=5, + init_measured_ratio=0.1, + init_min_unmeasured=50, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ), + space_generator=ScheduleFn( + sch_fn=_schedule_matmul_empty, + ), mutator_probs={ DummyMutator(): 1.0, }, target=tvm.target.Target("llvm"), - num_threads=1, # because we are using a mutator from the python side + num_threads=1, ) - _scheduler = RoundRobin( - tasks=[context], - task_weights=[1.0], - builder=ms.builder.LocalBuilder(), - runner=ms.runner.LocalRunner(), + context.initialize() + strategy = context.search_strategy + strategy.pre_tuning( + context.space_generator.generate_design_space(context.mod), database=DummyDatabase(), cost_model=ms.cost_model.RandomModel(), - measure_callbacks=[], - max_trials=1, ) - context.space_generator.initialize_with_tune_context(context) - spaces = context.space_generator.generate_design_space(context.mod) - - strategy.initialize_with_tune_context(context) - strategy.pre_tuning(spaces) - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] candidates = strategy.generate_measure_candidates() while candidates is not None: @@ -239,7 +227,6 @@ def _schedule_matmul_empty(sch: Schedule): candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [1, 0, 0, 0, 0] - del _scheduler if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 025bbe4225b54..f24dc5fbbc1fd 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -17,7 +17,6 @@ """ Test Meta Schedule Task Scheduler """ import random -import sys import weakref from typing import Set @@ -108,7 +107,6 @@ def main( # type: ignore def _schedule_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) k_0, k_1 = sch.split(loop=k, factors=[32, 32]) @@ -118,7 +116,6 @@ 
def _schedule_matmul(sch: Schedule): def _schedule_batch_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k, t = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 2, 2, 2]) j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[2, 4, 64, 2]) k_0, k_1 = sch.split(loop=k, factors=[32, 32]) @@ -156,23 +153,22 @@ def next_task_id(self) -> int: def test_meta_schedule_task_scheduler_single(): num_trials_per_iter = 3 max_trials_per_task = 10 - sch_fn = ScheduleFn(sch_fn=_schedule_matmul) - replay = ReplayTrace(num_trials_per_iter, max_trials_per_task) - task = TuneContext( - MatmulModule, - target=tvm.target.Target("llvm"), - space_generator=sch_fn, - search_strategy=replay, - task_name="Test", - rand_state=42, - ) database = DummyDatabase() round_robin = RoundRobin( - [task], + [ + TuneContext( + MatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + task_name="Test", + rand_state=42, + ) + ], [1.0], - DummyBuilder(), - DummyRunner(), - database, + builder=DummyBuilder(), + runner=DummyRunner(), + database=database, measure_callbacks=[measure_callback.AddToDatabase()], max_trials=max_trials_per_task, ) @@ -212,10 +208,10 @@ def test_meta_schedule_task_scheduler_multiple(): database = DummyDatabase() round_robin = RoundRobin( tasks, - [1.0], - DummyBuilder(), - DummyRunner(), - database, + [1.0, 1.0, 1.0], + builder=DummyBuilder(), + runner=DummyRunner(), + database=database, measure_callbacks=[measure_callback.AddToDatabase()], max_trials=max_trials_per_task * len(tasks), ) @@ -239,18 +235,23 @@ class NIETaskScheduler(PyTaskScheduler): pass with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"): - scheduler = NIETaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase(), 1) + scheduler = NIETaskScheduler( + tasks=[], + builder=DummyBuilder(), + runner=DummyRunner(), + database=DummyDatabase(), + max_trials=1, + ) scheduler.next_task_id() def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid-name - database = DummyDatabase() scheduler = MyTaskScheduler( [], - DummyBuilder(), - DummyRunner(), - database, + builder=DummyBuilder(), + runner=DummyRunner(), + database=database, measure_callbacks=[ measure_callback.AddToDatabase(), ], @@ -262,7 +263,6 @@ def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name - num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ @@ -294,9 +294,9 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d database = DummyDatabase() scheduler = MyTaskScheduler( tasks, - DummyBuilder(), - DummyRunner(), - database, + builder=DummyBuilder(), + runner=DummyRunner(), + database=database, measure_callbacks=[ measure_callback.AddToDatabase(), ], From a2ef144ea3aa8ae763c59cc596e73d6a89b3f046 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 7 Jun 2022 00:57:59 -0700 Subject: [PATCH 055/181] Refactor RewriteTensorize to prevent concurrent map updates (#11596) --- .../postproc/rewrite_tensorize.cc | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 1ad394e49c596..3df9075972963 100644 --- 
a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -28,10 +28,10 @@ namespace meta_schedule { using tir::BlockRV; using tir::LoopRV; -void ApplyTensorization(const tir::Schedule& sch, const String& func_name, - const tir::PrimFuncNode* func, bool vectorize_init_loop) { - std::vector>> jobs; - +void CollectTensorizationJobs( + const tir::Schedule& sch, const String& func_name, const tir::PrimFuncNode* func, + bool vectorize_init_loop, + std::vector>>* jobs) { tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { if (const auto* block = obj.as()) { tir::StmtSRef block_sref = sch->GetSRef(block); @@ -39,7 +39,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name, tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { std::string block_name = block_sref->StmtAs()->name_hint; if (block_name.find("init") == std::string::npos) { - jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) { + jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { try { sch->Tensorize(block, intrin_name.value()); } catch (const std::exception& e) { @@ -47,7 +47,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name, } }); } else if (vectorize_init_loop) { - jobs.emplace_back(block_name, [sch](tir::BlockRV block) { + jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { Array child_blocks = sch->GetChildBlocks(block); ICHECK(child_blocks.size() == 1); Array init_loops = sch->GetLoops(child_blocks[0]); @@ -58,12 +58,6 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name, } } }); - - for (auto kv : jobs) { - tir::BlockRV block = sch->GetBlock(kv.first, func_name); - sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); - kv.second(block); - } } class RewriteTensorizeNode : public PostprocNode { @@ -81,13 +75,23 @@ class RewriteTensorizeNode : public PostprocNode { }; bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { + // The rewriting jobs, 3-tuple (block_name, func_name, job_func) + std::vector>> jobs; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; if (const tir::PrimFuncNode* prim_func = base_func.as()) { - ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop); + CollectTensorizationJobs(sch, g_var->name_hint, prim_func, vectorize_init_loop, &jobs); } } + for (const auto& job : jobs) { + const String& block_name = std::get<0>(job); + const String& func_name = std::get<1>(job); + const auto& job_func = std::get<2>(job); + BlockRV block = sch->GetBlock(block_name, func_name); + sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); + job_func(block); + } return true; } From 70884e957aa5c8de9c02c25a14d30563d7300cb9 Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 7 Jun 2022 00:58:14 -0700 Subject: [PATCH 056/181] fix uint case (#11597) --- src/relay/transforms/fold_explicit_padding.cc | 3 ++- tests/python/relay/test_pass_fold_explicit_padding.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc index c60f36c7540e2..00162abc69f90 100644 --- a/src/relay/transforms/fold_explicit_padding.cc +++ b/src/relay/transforms/fold_explicit_padding.cc @@ -269,7 +269,8 @@ class SimplifyExplicitPad { } else if (node_map.count(avg_pool3d_)) { attrs = MakeAvgPoolAttrs(param, call_node->attrs.as()); } - } 
else if (node_map.count(max_pool_)) { + } + if (node_map.count(max_pool_)) { // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype)); const FloatImmNode* maybe_min_float = min_value.as(); diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py index 41e2500d4ffa9..35354508a953a 100644 --- a/tests/python/relay/test_pass_fold_explicit_padding.py +++ b/tests/python/relay/test_pass_fold_explicit_padding.py @@ -228,8 +228,8 @@ def validate( # Check Pool pad folding when pad width on pad op is all zero. validate(max_pools, 1, [[0, 0], [0, 0], [0, 0]], float_min_val, [2, 0], "NCW", 2) - # Check MaxPool pad folding with int dtype - int_min_val = get_min_value("int32") + # Check MaxPool pad folding with uint dtype + int_min_val = get_min_value("uint8") validate( max_pools, 2, @@ -238,7 +238,7 @@ def validate( [2, 0, 0, 0], "NCHW", 2, - dtype="int32", + dtype="uint8", ) # Fold when original AvgPool has its own padding but count_include_pad=True validate( From 32a86f8304928f16286cd9ffe6d47abc6c4a5bb6 Mon Sep 17 00:00:00 2001 From: Altan Haan <3124994+altanh@users.noreply.github.com> Date: Tue, 7 Jun 2022 10:33:21 -0700 Subject: [PATCH 057/181] [TOPI] TE implementation of LSTM using scan (#11531) * TE implementation of LSTM in TOPI * docstring * lint * add injective tags where applicable --- python/tvm/topi/generic/nn.py | 16 ++ python/tvm/topi/nn/__init__.py | 1 + python/tvm/topi/nn/lstm.py | 235 +++++++++++++++++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/lstm_python.py | 134 ++++++++++++ tests/python/topi/python/test_topi_lstm.py | 161 ++++++++++++++ 6 files changed, 548 insertions(+) create mode 100644 python/tvm/topi/nn/lstm.py create mode 100644 python/tvm/topi/testing/lstm_python.py create mode 100644 tests/python/topi/python/test_topi_lstm.py diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 4226c6caf23c9..80ea00ab01530 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -881,3 +881,19 @@ def schedule_correlation_nchw(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def schedule_lstm(outs): + """Schedule for LSTM + + Parameters + ---------- + outs : Array of Tensor + The outputs of LSTM (hidden states and cell states). + + Returns + ------- + sch: Schedule + The default schedule for LSTM. + """ + return _default_schedule(outs, False) diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index d3d00305a17b3..1dd922d76819c 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -51,3 +51,4 @@ from .space_to_batch_nd import * from .batch_to_space_nd import * from .loss import * +from .lstm import * diff --git a/python/tvm/topi/nn/lstm.py b/python/tvm/topi/nn/lstm.py new file mode 100644 index 0000000000000..b9723b5675d01 --- /dev/null +++ b/python/tvm/topi/nn/lstm.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""General LSTM implementation using TE scan.""" +from tvm import te, tir +from tvm.topi import tag + + +def lstm( + Xs, + Wi, + Wh, + Bi=None, + Bh=None, + h_init=None, + c_init=None, + proj=None, + p_i=None, + p_f=None, + p_o=None, + f_act=tir.sigmoid, + g_act=tir.tanh, + h_act=tir.tanh, + reverse=False, + weight_layout: str = "IFGO", +): + """General LSTM implemented using TE scan. + + Parameters + ---------- + Xs : te.Tensor + Input sequence with shape `(seq_len, batch_size, in_dim)` + Wi : te.Tensor + Input weight matrix with shape `(4 * hidden_dim, in_dim)`. The weights are packed according + to `weight_layout`. + Wh : te.Tensor + Hidden weight matrix with shape `(4 * hidden_dim, hidden_dim or proj_dim)`. Packed as `Wh`. + Bi : te.Tensor, optional + Input bias with shape `(4 * hidden_dim,)`, by default None. Packed as `Wh`. + Bh : te.Tensor, optional + Hidden bias with shape as `Bi`, by default None. Packed as `Wh`. + h_init : te.Tensor, optional + Initial hidden state with shape `(batch_size, hidden_dim or proj_dim)`, zero if None + c_init : te.Tensor, optional + Initial cell state with same shape as `h_init`, zero if None + proj : te.Tensor, optional + Projection matrix with shape `(proj_dim, hidden_dim)`, by default None + p_i, p_f, p_o : te.Tensor, optional + Peephole LSTM matrices with shape `(batch_size, hidden_dim)`, by default None + f_act, g_act, h_act : F, optional + Gate activation functions + reverse : bool, optional + Whether to process `Xs` in reverse, by default False + weight_layout : str, optional + The packed weight layout for gates, by default "IFGO". Note: I = input, F = forget, + G = cell, O = output. + + Returns + ------- + result : te.Tensor, te.Tensor + Tuple of hidden states (with shape `(seq_len, batch_size, hidden_dim or proj_dim)`), and + cell states (with shape `(seq_len, batch_size, hidden_dim)`). 
+ """ + assert len(weight_layout) == 4 and sorted(weight_layout) == sorted( + "IFGO" + ), f'given weight layout "{weight_layout}" is not a permutation of "IFGO"' + + i_gate_idx = weight_layout.find("I") + f_gate_idx = weight_layout.find("F") + g_gate_idx = weight_layout.find("G") + o_gate_idx = weight_layout.find("O") + + seq_len, batch_size, in_dim = Xs.shape + assert ( + Wi.shape[0] % 4 == 0 + ), f"dim 0 of input weight should be 4 * hidden_dim, but {Wi.shape[0]} is not divisible by 4" + hidden_dim = Wi.shape[0] // 4 + proj_dim = hidden_dim + if proj is not None: + proj_dim = proj.shape[0] + + # te.scan uses up 1 element for the initial value + scan_len = seq_len + 1 + + # precompute input-hidden matmul outside the scan + ki = te.reduce_axis((0, in_dim), name="ki2h") + Xi2h = te.compute( + (seq_len * batch_size, 4 * hidden_dim), + lambda tb, ij: te.sum(Xs[(tb // batch_size), tb % batch_size, ki] * Wi[ij, ki], axis=ki), + name="Xi2h", + ) + if Bi is not None: + Xi2h = te.compute( + Xi2h.shape, lambda tb, ij: Xi2h[tb, ij] + Bi[ij], name="Xi2h_bias", tag=tag.INJECTIVE + ) + + h_state = te.placeholder((scan_len, batch_size, proj_dim), name="h_state") + c_state = te.placeholder((scan_len, batch_size, hidden_dim), name="c_state") + h_init = te.compute( + (1, batch_size, proj_dim), + lambda _, b, i: h_init[b, i] if h_init is not None else 0.0, + name="h_init", + ) + c_init = te.compute( + (1, batch_size, hidden_dim), + lambda _, b, i: c_init[b, i] if c_init is not None else 0.0, + name="c_init", + ) + + # begin scan computations, first the (batched) hidden-hidden dense + kh = te.reduce_axis((0, proj_dim), name="kh2h") + s_h2h = te.compute( + (scan_len, batch_size, 4, hidden_dim), + lambda t, b, i, j: te.sum(h_state[t - 1, b, kh] * Wh[i * hidden_dim + j, kh], axis=kh), + name="s_h2h", + ) + if Bh is not None: + s_h2h = te.compute( + s_h2h.shape, + lambda t, b, i, j: s_h2h[t, b, i, j] + Bh[i * hidden_dim + j], + name="s_h2h_bias", + tag=tag.INJECTIVE, + ) + + # helper to reverse time if scanning backwards + get_x_t = lambda t: seq_len - t if reverse else t - 1 + + gates = te.compute( + (scan_len, batch_size, 4, hidden_dim), + lambda t, b, i, j: Xi2h[get_x_t(t) * batch_size + b, i * hidden_dim + j] + + s_h2h[t, b, i, j], + name="gates", + tag=tag.INJECTIVE, + ) + + # helper to correctly read each gate dense from the batched output + read_gate = lambda t, b, j, idx: gates[t, b, idx, j] + + gate_shape = (scan_len, batch_size, hidden_dim) + + # compute the activated gates (and do some extra stuff if peephole weights are present) + if p_i is not None and p_f is not None: + i_gate = te.compute( + gate_shape, + lambda t, b, j: f_act( + read_gate(t, b, j, i_gate_idx) + p_i[b, j] * c_state[t - 1, b, j] + ), + name="i_gate_p", + tag=tag.INJECTIVE, + ) + f_gate = te.compute( + gate_shape, + lambda t, b, j: f_act( + read_gate(t, b, j, f_gate_idx) + p_f[b, j] * c_state[t - 1, b, j] + ), + name="f_gate_p", + tag=tag.INJECTIVE, + ) + else: + i_gate = te.compute( + gate_shape, + lambda *i: f_act(read_gate(*i, i_gate_idx)), + name="i_gate", + tag=tag.INJECTIVE, + ) + f_gate = te.compute( + gate_shape, + lambda *i: f_act(read_gate(*i, f_gate_idx)), + name="f_gate", + tag=tag.INJECTIVE, + ) + + g_gate = te.compute( + gate_shape, lambda *i: g_act(read_gate(*i, g_gate_idx)), name="g_gate", tag=tag.INJECTIVE + ) + + next_c = te.compute( + gate_shape, + lambda t, b, j: f_gate[t, b, j] * c_state[t - 1, b, j] + i_gate[t, b, j] * g_gate[t, b, j], + name="next_c", + ) + + if p_o is not None: + o_gate = te.compute( + 
gate_shape, + lambda t, b, j: f_act(read_gate(t, b, j, o_gate_idx) + p_o[b, j] * next_c[t, b, j]), + name="o_gate_p", + tag=tag.INJECTIVE, + ) + else: + o_gate = te.compute( + gate_shape, + lambda *i: f_act(read_gate(*i, o_gate_idx)), + name="o_gate", + tag=tag.INJECTIVE, + ) + + next_h = te.compute(gate_shape, lambda *i: o_gate(*i) * h_act(next_c(*i)), name="next_h") + + # project hidden state back to proj_dim if projection matrix is present + if proj is not None: + kr = te.reduce_axis((0, hidden_dim), name="kh2p") + next_h = te.compute( + (scan_len, batch_size, proj_dim), + lambda t, b, j: te.sum(next_h[t, b, kr] * proj[j, kr], axis=kr), + name="next_h_proj", + ) + + scan_h, scan_c = te.scan( + [h_init, c_init], [next_h, next_c], [h_state, c_state], name="lstm_scan" + ) + + # drop the initial values, TODO(@altanh): is there a better way? + scan_h = te.compute( + (seq_len, batch_size, proj_dim), lambda t, b, j: scan_h[t + 1, b, j], name="hidden_states" + ) + scan_c = te.compute( + (seq_len, batch_size, hidden_dim), lambda t, b, j: scan_c[t + 1, b, j], name="cell_states" + ) + + return scan_h, scan_c diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 21ddf6fc55361..2f091cba10b7d 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -76,3 +76,4 @@ from .dense import dense from .searchsorted import searchsorted_ref from .conv2d_backcward_weight_python import conv2d_backward_weight_python +from .lstm_python import lstm_python diff --git a/python/tvm/topi/testing/lstm_python.py b/python/tvm/topi/testing/lstm_python.py new file mode 100644 index 0000000000000..ef1bce33658bc --- /dev/null +++ b/python/tvm/topi/testing/lstm_python.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
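As context for the `lstm` kernel above: everything hinges on `te.scan`, where the init stage occupies time step 0 and the update stage reads the state placeholder at `t - 1` (hence `scan_len = seq_len + 1`). A minimal standalone cumulative-sum sketch of that idiom, with illustrative names only, not part of this patch:

```
# Minimal te.scan sketch (illustrative only): a running sum over rows of X.
# The init stage fills row 0; the update stage reads the state at t - 1,
# the same pattern the LSTM kernel uses for its hidden and cell states.
import tvm
from tvm import te

m = te.var("m")  # sequence length
n = te.var("n")  # feature width
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n), name="s_state")
s_init = te.compute((1, n), lambda _, i: X[0, i], name="s_init")
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i], name="s_update")
s_scan = te.scan(s_init, s_update, s_state, inputs=[X], name="running_sum")

s = te.create_schedule(s_scan.op)
print(tvm.lower(s, [X, s_scan], simple_mode=True))
```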
+# pylint: disable=invalid-name +"""LSTM reference implementation using numpy.""" +import numpy as np + + +def lstm_python( + Xs: np.array, + Wi: np.array, + Wh: np.array, + Bi: np.array = None, + Bh: np.array = None, + h_init: np.array = None, + c_init: np.array = None, + proj: np.array = None, + p_i: np.array = None, + p_f: np.array = None, + p_o: np.array = None, + f_act: str = "sigmoid", + g_act: str = "tanh", + h_act: str = "tanh", + reverse: bool = False, + weight_layout: str = "IFGO", +): + """LSTM reference implementation using numpy + + Parameters + ---------- + Xs : np.array + (seq_length, batch_size, in_dim) + Wi : np.array + (4 * hidden_dim, in_dim) + Wh : np.array + (4 * hidden_dim, out_dim) where out_dim = proj_dim if proj_dim > 0, else hidden_dim + Bi : np.array, optional + (4 * hidden_dim,), by default None + Bh : np.array, optional + (4 * hidden_dim,), by default None + h_init : np.array, optional + (batch_size, out_dim), by default None + c_init : np.array, optional + (batch_size, hidden_dim), by default None + proj : np.array, optional + (proj_dim, hidden_dim), by default None + p_i, p_f, p_o: np.array, optional + (batch_size, hidden_dim), by default None + f_act, g_act, h_act: str, optional + activations, by default "sigmoid", "tanh", "tanh" + reverse : bool, optional + process Xs in reverse, by default False + weight_layout : str, optional + Packed layout for weights and biases, by default "IFGO" + """ + i_gate_idx = weight_layout.find("I") + f_gate_idx = weight_layout.find("F") + g_gate_idx = weight_layout.find("G") + o_gate_idx = weight_layout.find("O") + + str2act = {"sigmoid": lambda x: 1 / (1 + np.exp(-x)), "tanh": np.tanh} + + f_act = str2act[f_act] + g_act = str2act[g_act] + h_act = str2act[h_act] + + S, B, F = Xs.shape + H = Wi.shape[0] // 4 + O = Wh.shape[1] + + # make life a bit easier + Wi = np.reshape(Wi, (4, H, F)) + Wh = np.reshape(Wh, (4, H, O)) + if Bi is not None: + Bi = np.reshape(Bi, (4, H)) + if Bh is not None: + Bh = np.reshape(Bh, (4, H)) + + h0 = h_init if h_init is not None else np.zeros((B, O), "float32") + c0 = c_init if c_init is not None else np.zeros((B, H), "float32") + + hs = [h0] + cs = [c0] + + for t in range(S): + x = Xs[S - t - 1 if reverse else t] + xh = [np.matmul(x, Wi[g].T) for g in range(4)] + if Bi is not None: + xh = [xh[g] + Bi[g] for g in range(4)] + + hh = [np.matmul(hs[t], Wh[g].T) for g in range(4)] + if Bh is not None: + hh = [hh[g] + Bh[g] for g in range(4)] + + sums = [xh[g] + hh[g] for g in range(4)] + + if p_i is not None and p_f is not None: + i_gate = f_act(sums[i_gate_idx] + p_i * cs[t]) + f_gate = f_act(sums[f_gate_idx] + p_f * cs[t]) + else: + i_gate = f_act(sums[i_gate_idx]) + f_gate = f_act(sums[f_gate_idx]) + + g_gate = g_act(sums[g_gate_idx]) + + next_c = f_gate * cs[t] + i_gate * g_gate + + if p_o is not None: + o_gate = f_act(sums[o_gate_idx] + p_o * next_c) + else: + o_gate = f_act(sums[o_gate_idx]) + + next_h = o_gate * h_act(next_c) + + if proj is not None: + next_h = np.matmul(next_h, proj.T) + + hs.append(next_h) + cs.append(next_c) + + return np.stack(hs[1:], axis=0), np.stack(cs[1:], axis=0) diff --git a/tests/python/topi/python/test_topi_lstm.py b/tests/python/topi/python/test_topi_lstm.py new file mode 100644 index 0000000000000..08ed5d73523d0 --- /dev/null +++ b/tests/python/topi/python/test_topi_lstm.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Test code for LSTM.""" +import numpy as np +from rsa import verify +import tvm +from tvm import te, topi +import tvm.testing +import tvm.topi.testing + + +def verify_lstm( + target, + dev, + seq_len, + batch_size, + in_dim, + hidden_dim, + proj_dim=0, + bias=True, + zero_init=True, + peephole=False, + reverse=False, + weight_layout="IFGO", +): + out_dim = proj_dim if proj_dim > 0 else hidden_dim + + def rand(*shape): + sqrt_k = np.sqrt(1 / hidden_dim) + return np.random.uniform(-sqrt_k, sqrt_k, size=shape).astype("float32") + + def get_ref_data(): + Xs = np.random.normal(size=(seq_len, batch_size, in_dim)).astype("float32") + Wi = rand(4 * hidden_dim, in_dim) + Wh = rand(4 * hidden_dim, out_dim) + Bi = None + Bh = None + h0 = None + c0 = None + proj = None + p_i = None + p_f = None + p_o = None + + if bias: + Bi = rand(4 * hidden_dim) + Bh = rand(4 * hidden_dim) + + if not zero_init: + h0 = np.random.normal(size=(batch_size, out_dim)).astype("float32") + c0 = np.random.normal(size=(batch_size, hidden_dim)).astype("float32") + + if proj_dim > 0: + proj = rand(proj_dim, hidden_dim) + + if peephole: + p_i, p_f, p_o = [rand(batch_size, hidden_dim) for _ in range(3)] + + hs, cs = tvm.topi.testing.lstm_python( + Xs, + Wi, + Wh, + Bi=Bi, + Bh=Bh, + h_init=h0, + c_init=c0, + proj=proj, + p_i=p_i, + p_f=p_f, + p_o=p_o, + reverse=reverse, + weight_layout=weight_layout, + ) + + return [Xs, Wi, Wh, Bi, Bh, h0, c0, proj, p_i, p_f, p_o], [hs, cs] + + args_np, (hs_np, cs_np) = get_ref_data() + + args = [te.placeholder(a.shape, "float32") if a is not None else a for a in args_np] + real_args = [a for a in args if a is not None] + + hs, cs = topi.nn.lstm(*args, reverse=reverse, weight_layout=weight_layout) + with tvm.target.Target(target): + sch = topi.generic.schedule_lstm([hs, cs]) + func = tvm.build(sch, real_args + [hs, cs], target=target) + + args_nd = [tvm.nd.array(a, dev) for a in args_np if a is not None] + hs_nd = tvm.nd.array(np.zeros((seq_len, batch_size, out_dim), "float32"), dev) + cs_nd = tvm.nd.array(np.zeros((seq_len, batch_size, hidden_dim), "float32"), dev) + func(*args_nd, hs_nd, cs_nd) + + tvm.testing.assert_allclose(hs_nd.numpy(), hs_np, rtol=1e-4) + tvm.testing.assert_allclose(cs_nd.numpy(), cs_np, rtol=1e-4) + + +def test_lstm(): + verify_lstm( + "llvm", + tvm.cpu(0), + 1, + 1, + 1, + 1, + 0, + True, + True, + False, + False, + "IFGO", + ) + + verify_lstm( + "llvm", + tvm.cpu(0), + 8, + 4, + 8, + 16, + 0, + True, + False, + False, + False, + "IFGO", + ) + + +def test_lstm_proj(): + verify_lstm("llvm", tvm.cpu(0), 8, 4, 16, 32, 8, True, True, False, False, "IFGO") + + +def test_lstm_peephole(): + verify_lstm("llvm", tvm.cpu(0), 8, 4, 16, 32, 0, True, True, True, False, "IFGO") + + +def test_lstm_reverse(): + verify_lstm("llvm", tvm.cpu(0), 8, 4, 16, 32, 0, True, True, 
False, True, "IFGO") + + +def test_lstm_weight_layout_iofg(): + # IOFG is used by ONNX, while IFGO is used by PyTorch + verify_lstm("llvm", tvm.cpu(0), 8, 4, 16, 32, 0, True, True, False, False, "IOFG") + + +def test_lstm_assorted(): + verify_lstm("llvm", tvm.cpu(0), 8, 4, 16, 32, 16, True, False, True, True, "OIGF") From 12440895e4baad1de494f0a3876edee3e1df06ee Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 7 Jun 2022 11:08:32 -0700 Subject: [PATCH 058/181] [MetaSchedule] Add Testing Script with ONNX Support (#11587) This PR introduces 2 tuning script for meta schedule and auto scheduler tuning support with onnx files. Now we can easily introduce onnx models benchmarking with command line scripts. Sample tuning call looks similar to the following script For Meta Schedule ONNX tuning: ``` python3 -m tvm.meta_schedule.testing.tune_onnx_meta_schedule \ --model-name "$MODEL_NAME" \ --onnx-path "$ONNX_PATH" \ --input-shape "$INPUT_SHAPE" \ --target "$TARGET" \ --num-trials $NUM_TRIALS \ --rpc-host $RPC_HOST \ --rpc-port $RPC_PORT \ --rpc-key $RPC_KEY \ --rpc-workers $RPC_WORKERS \ --work-dir $WORK_DIR \ |& tee "$WORK_DIR/$MODEL_NAME.log" ``` For AutoScheduler ONNX tuning: ``` python3 -m tvm.meta_schedule.testing.tune_onnx_auto_scheduler \ --model-name "$MODEL_NAME" \ --onnx-path "$ONNX_PATH" \ --input-shape "$INPUT_SHAPE" \ --target "$TARGET" \ --num-trials $NUM_TRIALS \ --rpc-host $RPC_HOST \ --rpc-port $RPC_PORT \ --rpc-key $RPC_KEY \ --rpc-workers $RPC_WORKERS \ --log-dir $WORK_DIR \ |& tee "$WORK_DIR/$MODEL_NAME.log" ``` --- .../testing/tune_onnx_auto_scheduler.py | 238 ++++++++++++++++++ .../testing/tune_onnx_meta_schedule.py | 199 +++++++++++++++ .../testing/tune_relay_auto_scheduler.py | 4 +- 3 files changed, 439 insertions(+), 2 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py create mode 100644 python/tvm/meta_schedule/testing/tune_onnx_meta_schedule.py diff --git a/python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py new file mode 100644 index 0000000000000..e916f5ace3393 --- /dev/null +++ b/python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
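Both new scripts describe the network inputs through `--input-shape`, a JSON list with one `{name, dtype, shape}` entry per graph input (see the `help` string below). A small sketch of composing that argument and of the kind of random inputs the scripts derive from it; the input name and shape here are placeholders, not values taken from this patch:

```
# Sketch of the --input-shape JSON consumed by both scripts: one entry per
# graph input with "name", "dtype" and "shape" keys. The input name and the
# 1x3x224x224 shape below are placeholders.
import json

import numpy as np

input_shape = [{"name": "data", "dtype": "float32", "shape": [1, 3, 224, 224]}]
arg = json.dumps(input_shape)  # string to pass via --input-shape

# Random inputs are then generated per entry, mirroring the scripts:
# uniform floats for float dtypes, bounded random integers otherwise.
inputs = {}
for item in json.loads(arg):
    if item["dtype"].startswith("float"):
        inputs[item["name"]] = np.random.uniform(size=item["shape"]).astype(item["dtype"])
    else:
        inputs[item["name"]] = np.random.randint(0, 10000, size=item["shape"]).astype(item["dtype"])
print({name: value.shape for name, value in inputs.items()})
```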
+# pylint: disable=missing-docstring +import argparse +import json +import os + +import numpy as np # type: ignore +import onnx # type: ignore +import tvm +from tvm.relay.frontend import from_onnx +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm import relay +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--model-name", + type=str, + required=True, + ) + args.add_argument( + "--onnx-path", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + help='example: `[{"name": "input1", "dtype": "int64", "shape": [1, 1, 8]}]', + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--rpc-workers", + type=int, + required=True, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=3600, + ) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.work_dir, f"{ARGS.model_name}.json") + + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + number=3, + repeat=1, + min_repeat_ms=100, # TODO + enable_cpu_cache_flush=False, # TODO + ) + + if ARGS.target.kind.name == "llvm": + hardware_params = auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ) + elif ARGS.target.kind.name == "cuda": + hardware_params = auto_scheduler.HardwareParams( + num_cores=-1, + vector_unit_bytes=16, + cache_line_bytes=64, + max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]), + max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]), + # The value `max_local_memory_per_block` is not used in AutoScheduler, + # but is required by the API. 
+ max_local_memory_per_block=12345678, + max_vthread_extent=8, + warp_size=32, + ) + else: + raise NotImplementedError(f"Unsupported target {ARGS.target}") + + print(f"Workload: {ARGS.model_name}") + onnx_model = onnx.load(ARGS.onnx_path) + shape_dict = {} + for item in ARGS.input_shape: + print(f" input_name: {item['name']}") + print(f" input_shape: {item['shape']}") + print(f" input_dtype: {item['dtype']}") + shape_dict[item["name"]] = item["shape"] + mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True) + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], + params, + target=ARGS.target, + hardware_params=hardware_params, + ) + for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)): + print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====") + print(task.compute_dag) + + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tuner.tune( + auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + runner=runner, + measure_callbacks=[ + auto_scheduler.RecordToFile(log_file), + ], + ) + ) + + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + ): + lib = relay.build( + mod, + target=ARGS.target, + params=params, + ) + graph, rt_mod, params = lib.graph_json, lib.lib, lib.params + input_data = {} + for item in ARGS.input_shape: + input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"] + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) + + def f_timer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.graph_executor import GraphModule + + # pylint: enable=import-outside-toplevel + + mod = GraphModule(rt_mod["default"](dev)) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) + ftimer = mod.module.time_evaluator( + "run", + dev, + min_repeat_ms=500, + repeat=3, + ) + results = list(np.array(ftimer().results) * 1000.0) # type: ignore + print("Running time in time_evaluator: ", results) + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=lib, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_timer, + ) + + def f_per_layer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.debugger.debug_executor import create + + # pylint: enable=import-outside-toplevel + mod = create(graph, rt_mod, dev) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) + graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] + graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) + print("|graph_nodes| = ", len(graph_nodes)) + print("|graph_time| = ", len(graph_time)) + graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)} + for k, v in graph_nodes_time.items(): + print(f"{k} : {v:.3f}") + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=rt_mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_per_layer, + ) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/tune_onnx_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_onnx_meta_schedule.py new file mode 100644 index 0000000000000..f5c7d1cde80b4 --- /dev/null +++ 
b/python/tvm/meta_schedule/testing/tune_onnx_meta_schedule.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import json +import logging +import numpy as np # type: ignore +import onnx # type: ignore +import tvm +from tvm.relay.frontend import from_onnx +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--model-name", + type=str, + required=True, + ) + args.add_argument( + "--onnx-path", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + help='example: `[{"name": "input1", "dtype": "int64", "shape": [1, 1, 8]}]', + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--rpc-workers", + type=int, + required=True, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=3600, + ) + return parsed + + +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) +ARGS = _parse_args() + + +def main(): + print(f"Workload: {ARGS.model_name}") + onnx_model = onnx.load(ARGS.onnx_path) + shape_dict = {} + for item in ARGS.input_shape: + print(f" input_name: {item['name']}") + print(f" input_shape: {item['shape']}") + print(f" input_dtype: {item['dtype']}") + shape_dict[item["name"]] = item["shape"] + mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True) + alloc_repeat = 1 + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ), + alloc_repeat=alloc_repeat, + max_workers=ARGS.rpc_workers, + ) + lib = ms.tune_relay( + mod=mod, + target=ARGS.target, + config=ms.TuneConfig( + strategy="evolutionary", + num_trials_per_iter=64, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, + ), + runner=runner, # type: ignore + work_dir=ARGS.work_dir, + params=params, + ) + graph, rt_mod, params = lib.graph_json, 
lib.lib, lib.params + input_data = {} + for item in ARGS.input_shape: + input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"] + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) + + def f_timer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.graph_executor import GraphModule + + # pylint: enable=import-outside-toplevel + + mod = GraphModule(rt_mod["default"](dev)) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) + ftimer = mod.module.time_evaluator( + "run", + dev, + min_repeat_ms=500, + repeat=3, + ) + results = list(np.array(ftimer().results) * 1000.0) # type: ignore + print("Running time in time_evaluator: ", results) + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=lib, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_timer, + ) + + def f_per_layer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.debugger.debug_executor import create + + # pylint: enable=import-outside-toplevel + mod = create(graph, rt_mod, dev) + for input_name, input_value in input_data.items(): + mod.set_input(input_name, input_value) + graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] + graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) + print("|graph_nodes| = ", len(graph_nodes)) + print("|graph_time| = ", len(graph_time)) + graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)} + for k, v in graph_nodes_time.items(): + print(f"{k} : {v:.3f}") + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=rt_mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_per_layer, + ) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py index abac49c50c6ee..ff4f9313470c9 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -71,7 +71,7 @@ def _parse_args(): required=True, ) args.add_argument( - "--log-dir", + "--work-dir", type=str, required=True, ) @@ -96,7 +96,7 @@ def _parse_args(): def main(): - log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json") runner = auto_scheduler.RPCRunner( key=ARGS.rpc_key, From 81702192b49ddb37ce3e179eec3e88f3726acec1 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 7 Jun 2022 13:38:03 -0500 Subject: [PATCH 059/181] [MetaSchedule] Resolve dependencies between header files (#11604) * [MetaSchedule] Resolve dependencies between header files After PR11590 TVM stopped compiling with clang-14 and libc++. The problems were caused by incomplete types used in contexts where complete types were required. To resolve this, some code had to be moved into .cc files. Also the MeasureCandidate classes needed to be added to their own include files (or otherwise there would be a circular dependency between headers). All headers from the meta_schedule directory were updated to include all their dependencies (forward declarations were left where appropriate). 
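The underlying rule is the usual incomplete-type restriction: a forward declaration is enough for pointers, references, and declarations, but member access, copies, sizeof, and many libc++ containers need the complete definition at the point of use, which is why the inline method bodies moved into .cc files. A minimal standalone illustration with hypothetical names (not TVM code):

```
// Minimal illustration with hypothetical names: why code that compiled with
// one standard library can fail under libc++/clang-14 when a type is still
// incomplete at the point of use.
#include <vector>

class Widget;                          // forward declaration: incomplete type

void Declare(const Widget& w);         // OK: only the name is required
// int Broken(const Widget& w) { return w.Size(); }  // error: incomplete type

class Widget {                         // complete definition
 public:
  int Size() const { return 3; }
};

// OK: the full definition is visible here, so member access and instantiating
// std::vector<Widget> are both allowed (e.g. from a .cc file).
int Works(const Widget& w) { return w.Size(); }

int main() {
  std::vector<Widget> v(2);
  return Works(v.front()) - 3;
}
```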
* Fix a typo: PySpaceGeneratorCode -> PySpaceGeneratorNode --- .../tvm/meta_schedule/apply_history_best.h | 9 ++- include/tvm/meta_schedule/arg_info.h | 3 + include/tvm/meta_schedule/builder.h | 8 +++ include/tvm/meta_schedule/cost_model.h | 34 ++++----- include/tvm/meta_schedule/database.h | 7 ++ include/tvm/meta_schedule/extracted_task.h | 7 +- include/tvm/meta_schedule/feature_extractor.h | 13 ++-- include/tvm/meta_schedule/measure_callback.h | 11 +-- include/tvm/meta_schedule/measure_candidate.h | 67 ++++++++++++++++++ include/tvm/meta_schedule/mutator.h | 18 +++-- include/tvm/meta_schedule/postproc.h | 15 ++-- include/tvm/meta_schedule/runner.h | 6 ++ include/tvm/meta_schedule/schedule_rule.h | 20 +++--- include/tvm/meta_schedule/search_strategy.h | 69 ++++--------------- include/tvm/meta_schedule/space_generator.h | 21 +++--- include/tvm/meta_schedule/task_scheduler.h | 47 +++---------- include/tvm/meta_schedule/tune_context.h | 8 +++ src/meta_schedule/cost_model/cost_model.cc | 24 +++++++ .../feature_extractor/feature_extractor.cc | 6 ++ .../measure_callback/measure_callback.cc | 9 +++ src/meta_schedule/mutator/mutator.cc | 12 ++++ src/meta_schedule/postproc/postproc.cc | 11 +++ .../schedule_rule/schedule_rule.cc | 12 ++++ .../search_strategy/search_strategy.cc | 27 +++++++- .../space_generator/space_generator.cc | 12 ++++ .../task_scheduler/task_scheduler.cc | 37 ++++++++++ 26 files changed, 344 insertions(+), 169 deletions(-) create mode 100644 include/tvm/meta_schedule/measure_candidate.h diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h index b5504a8ee0f8c..5b1816cef41ff 100644 --- a/include/tvm/meta_schedule/apply_history_best.h +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -19,7 +19,14 @@ #ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ #define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ +#include #include +#include +#include +#include +#include +#include +#include #include namespace tvm { @@ -36,7 +43,7 @@ class ApplyHistoryBestNode : public runtime::Object { /*! \brief The logging function to be used */ PackedFunc logging_func; - void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("database", &database); } /*! 
* \brief Query the best entry from the database * \param task_name The name of the task to be queried diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h index 08553a001374e..c7dd3c7f65385 100644 --- a/include/tvm/meta_schedule/arg_info.h +++ b/include/tvm/meta_schedule/arg_info.h @@ -20,7 +20,10 @@ #define TVM_META_SCHEDULE_ARG_INFO_H_ #include +#include #include +#include +#include #include namespace tvm { diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 2b809459155ec..e41dc900a00da 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -20,6 +20,14 @@ #define TVM_META_SCHEDULE_BUILDER_H_ #include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace tvm { diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 6fadc2fb9c137..91d19c430b1fe 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -20,7 +20,15 @@ #ifndef TVM_META_SCHEDULE_COST_MODEL_H_ #define TVM_META_SCHEDULE_COST_MODEL_H_ -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include @@ -126,28 +134,12 @@ class PyCostModelNode : public CostModelNode { // `f_as_string` is not visited } - void Load(const String& path) { - ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; - f_load(path); - } - - void Save(const String& path) { - ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; - f_save(path); - } + void Load(const String& path); + void Save(const String& path); void Update(const TuneContext& context, const Array& candidates, - const Array& results) { - ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; - f_update(context, candidates, results); - } - + const Array& results); std::vector Predict(const TuneContext& context, - const Array& candidates) { - ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; - std::vector result(candidates.size(), 0.0); - f_predict(context, candidates, result.data()); - return result; - } + const Array& candidates); static constexpr const char* _type_key = "meta_schedule.PyCostModel"; TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index f07d8e1366441..1353dec3eda3f 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -19,7 +19,14 @@ #ifndef TVM_META_SCHEDULE_DATABASE_H_ #define TVM_META_SCHEDULE_DATABASE_H_ +#include +#include #include +#include +#include +#include +#include +#include #include #include diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index c6613427fd5b6..898b974d87726 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -19,6 +19,11 @@ #ifndef TVM_META_SCHEDULE_EXTRACTED_TASK_H_ #define TVM_META_SCHEDULE_EXTRACTED_TASK_H_ +#include +#include +#include +#include +#include #include namespace tvm { @@ -38,7 +43,7 @@ class ExtractedTaskNode : public runtime::Object { /*! 
\brief Weight of the task */ int weight; - void VisitAttrs(AttrVisitor* v) { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task_name", &task_name); v->Visit("mod", &mod); v->Visit("target", &target); diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index c2ca2beb9b686..02e9f26b2a600 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -20,7 +20,13 @@ #ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ #define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ -#include +#include +#include +#include +#include +#include +#include +#include namespace tvm { namespace meta_schedule { @@ -76,10 +82,7 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { } Array ExtractFrom(const TuneContext& context, - const Array& candidates) { - ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; - return f_extract_from(context, candidates); - } + const Array& candidates) final; static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index e9abb123012ab..151582d4c9ce6 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -21,9 +21,15 @@ #define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ #include +#include #include #include #include +#include +#include +#include +#include +#include namespace tvm { namespace meta_schedule { @@ -94,10 +100,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { int task_id, // const Array& measure_candidates, // const Array& builds, // - const Array& results) final { - ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; - return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results); - } + const Array& results); static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h new file mode 100644 index 0000000000000..f7257b56d2067 --- /dev/null +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ +#define TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The schedule (with input shapes) to be measured. 
*/ +class MeasureCandidateNode : public runtime::Object { + public: + /*! \brief The schedule for measurement. */ + tir::Schedule sch; + /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch", &sch); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); +}; + +/*! + * \brief Managed reference to MeasureCandidateNode. + * \sa MeasureCandidateNode + */ +class MeasureCandidate : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of MeasureCandidate. + * \param sch The schedule for measurement. + * \param args_info The argument information, e.g., (shape, dtype) for tensors. + */ + TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index d80fa70eee8a2..566cc82e9716d 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -20,7 +20,13 @@ #ifndef TVM_META_SCHEDULE_MUTATOR_H_ #define TVM_META_SCHEDULE_MUTATOR_H_ +#include +#include +#include +#include +#include #include +#include namespace tvm { namespace meta_schedule { @@ -89,17 +95,9 @@ class PyMutatorNode : public MutatorNode { // `f_as_string` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(f_initialize_with_tune_context != nullptr) - << "PyMutator's InitializeWithTuneContext method not implemented!"; - this->f_initialize_with_tune_context(context); - } - + void InitializeWithTuneContext(const TuneContext& context) final; Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) final { - ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; - return this->f_apply(trace, *rand_state); - } + support::LinearCongruentialEngine::TRandState* rand_state) final; static constexpr const char* _type_key = "meta_schedule.PyMutator"; TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 195d558550170..738e726aa146b 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -20,6 +20,9 @@ #ifndef TVM_META_SCHEDULE_POSTPROC_H_ #define TVM_META_SCHEDULE_POSTPROC_H_ +#include +#include +#include #include namespace tvm { @@ -88,16 +91,8 @@ class PyPostprocNode : public PostprocNode { // `f_as_string` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(f_initialize_with_tune_context != nullptr) - << "PyPostproc's InitializeWithTuneContext method not implemented!"; - this->f_initialize_with_tune_context(context); - } - - bool Apply(const tir::Schedule& sch) final { - ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; - return this->f_apply(sch); - } + void InitializeWithTuneContext(const TuneContext& context) final; + bool Apply(const tir::Schedule& sch) final; static constexpr const char* _type_key = "meta_schedule.PyPostproc"; TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 
61023c8e2db05..c095728369312 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -21,6 +21,12 @@ #include #include +#include +#include +#include +#include +#include +#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index b39c72e24db8e..7e0e5bda57b60 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -20,6 +20,14 @@ #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#include +#include +#include +#include +#include +#include +#include +#include #include namespace tvm { @@ -90,16 +98,8 @@ class PyScheduleRuleNode : public ScheduleRuleNode { // `f_as_string` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(f_initialize_with_tune_context != nullptr) - << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; - this->f_initialize_with_tune_context(context); - } - - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { - ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; - return this->f_apply(sch, block); - } + void InitializeWithTuneContext(const TuneContext& context) final; + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 139de7c99d042..baae22f0d98ec 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -20,7 +20,15 @@ #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ #include +#include +#include +#include #include +#include +#include +#include +#include +#include #include namespace tvm { @@ -28,40 +36,6 @@ namespace meta_schedule { // Forward declaration class TuneContext; -class CostModel; -class Database; - -/*! \brief The schedule (with input shapes) to be measured. */ -class MeasureCandidateNode : public runtime::Object { - public: - /*! \brief The schedule for measurement. */ - tir::Schedule sch; - /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ - Array args_info; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("sch", &sch); - v->Visit("args_info", &args_info); - } - - static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; - TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); -}; - -/*! - * \brief Managed reference to MeasureCandidateNode. - * \sa MeasureCandidateNode - */ -class MeasureCandidate : public runtime::ObjectRef { - public: - /*! - * \brief Constructor of MeasureCandidate. - * \param sch The schedule for measurement. - * \param args_info The argument information, e.g., (shape, dtype) for tensors. - */ - TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); -}; /*! * \brief The search strategy for measure candidates generation. 
@@ -198,33 +172,14 @@ class PySearchStrategyNode : public SearchStrategyNode { // `f_notify_runner_results` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(f_initialize_with_tune_context != nullptr) - << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; - this->f_initialize_with_tune_context(context); - } - + void InitializeWithTuneContext(const TuneContext& context) final; void PreTuning(const Array& design_spaces, const Optional& database, const Optional& cost_model) final; - - void PostTuning() final { - ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; - this->f_post_tuning(); - } - - Optional> GenerateMeasureCandidates() final { - ICHECK(f_generate_measure_candidates != nullptr) - << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; - return this->f_generate_measure_candidates(); - } - + void PostTuning() final; + Optional> GenerateMeasureCandidates() final; void NotifyRunnerResults(const TuneContext& context, const Array& measure_candidates, - const Array& results) final { - ICHECK(f_notify_runner_results != nullptr) - << "PySearchStrategy's NotifyRunnerResults method not implemented!"; - this->f_notify_runner_results(context, measure_candidates, results); - } + const Array& results); static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index bad9ae0f6c6e9..f7d6cac31cab6 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -20,6 +20,10 @@ #define TVM_META_SCHEDULE_SPACE_GENERATOR_H_ #include +#include +#include +#include +#include #include namespace tvm { @@ -64,7 +68,7 @@ class TuneContext; │ └─── Runner Future ◄────┘ │ └─────────────────────────────────────────────────────────────────────┘ */ -class SpaceGeneratorNode : public Object { +class SpaceGeneratorNode : public runtime::Object { public: /*! \brief Default destructor */ virtual ~SpaceGeneratorNode() = default; @@ -112,17 +116,8 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { // `f_generate_design_space` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(f_initialize_with_tune_context != nullptr) - << "PySpaceGenerator's InitializeWithTuneContext method not implemented!"; - f_initialize_with_tune_context(context); - } - - Array GenerateDesignSpace(const IRModule& mod) final { - ICHECK(f_generate_design_space != nullptr) - << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; - return f_generate_design_space(mod); - } + void InitializeWithTuneContext(const TuneContext& context) final; + Array GenerateDesignSpace(const IRModule& mod) final; static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); @@ -132,7 +127,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { * \brief Managed reference to SpaceGeneratorNode. 
* \sa SpaceGeneratorNode */ -class SpaceGenerator : public ObjectRef { +class SpaceGenerator : public runtime::ObjectRef { protected: SpaceGenerator() = default; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 5953a2c3e42b1..385816e790e29 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -25,6 +25,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include namespace tvm { namespace meta_schedule { @@ -181,42 +187,11 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { // `f_next_task_id` is not visited } - void Tune() final { - if (f_tune == nullptr) { - TaskSchedulerNode::Tune(); - } else { - f_tune(); - } - } - - void InitializeTask(int task_id) final { - if (f_initialize_task == nullptr) { - TaskSchedulerNode::InitializeTask(task_id); - } else { - f_initialize_task(task_id); - } - } - - void TouchTask(int task_id) final { - if (f_touch_task == nullptr) { - return TaskSchedulerNode::TouchTask(task_id); - } else { - return f_touch_task(task_id); - } - } - - Array JoinRunningTask(int task_id) final { - if (f_join_running_task == nullptr) { - return TaskSchedulerNode::JoinRunningTask(task_id); - } else { - return f_join_running_task(task_id); - } - } - - int NextTaskId() final { - ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; - return f_next_task_id(); - } + void Tune() final; + void InitializeTask(int task_id) final; + void TouchTask(int task_id) final; + Array JoinRunningTask(int task_id) final; + int NextTaskId() final; static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler"; TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index d63fb819f3639..ee09099d1a926 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -19,6 +19,7 @@ #ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ +#include #include #include #include @@ -27,6 +28,13 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include #include diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index c6efb54303360..aabab5d83a1c9 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -21,6 +21,30 @@ namespace tvm { namespace meta_schedule { +void PyCostModelNode::Load(const String& path) { + ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + f_load(path); +} + +void PyCostModelNode::Save(const String& path) { + ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + f_save(path); +} + +void PyCostModelNode::Update(const TuneContext& context, const Array& candidates, + const Array& results) { + ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + f_update(context, candidates, results); +} + +std::vector PyCostModelNode::Predict(const TuneContext& context, + const Array& candidates) { + ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + std::vector result(candidates.size(), 0.0); + f_predict(context, candidates, result.data()); + return result; +} + CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // PyCostModelNode::FSave f_save, // PyCostModelNode::FUpdate 
f_update, // diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index 84d22493aaa6d..1ebbb6e2e2339 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -21,6 +21,12 @@ namespace tvm { namespace meta_schedule { +Array PyFeatureExtractorNode::ExtractFrom( + const TuneContext& context, const Array& candidates) { + ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + return f_extract_from(context, candidates); +} + FeatureExtractor FeatureExtractor::PyFeatureExtractor( PyFeatureExtractorNode::FExtractFrom f_extract_from, // PyFeatureExtractorNode::FAsString f_as_string) { diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 733d118c735d3..c7851a6fadf62 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -21,6 +21,15 @@ namespace tvm { namespace meta_schedule { +void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + return f_apply(task_scheduler, task_id, measure_candidates, builds, results); +} + MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // PyMeasureCallbackNode::FAsString f_as_string) { ObjectPtr n = make_object(); diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 27383adf84e0e..43b95000c71d4 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -21,6 +21,18 @@ namespace tvm { namespace meta_schedule { +void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + f_initialize_with_tune_context(context); +} + +Optional PyMutatorNode::Apply( + const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return f_apply(trace, *rand_state); +} + Mutator Mutator::PyMutator( PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyMutatorNode::FApply f_apply, // diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index ff069e2c68cbd..0f4f1b1192f65 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -21,6 +21,17 @@ namespace tvm { namespace meta_schedule { +void PyPostprocNode::InitializeWithTuneContext(const TuneContext& context) { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + f_initialize_with_tune_context(context); +} + +bool PyPostprocNode::Apply(const tir::Schedule& sch) { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return f_apply(sch); +} + Postproc Postproc::PyPostproc( PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyPostprocNode::FApply f_apply, // diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index f80f684dafa81..80f8725b0c0d7 100644 
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -21,6 +21,18 @@ namespace tvm { namespace meta_schedule { +void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; + f_initialize_with_tune_context(context); +} + +Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block) { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + return f_apply(sch, block); +} + ScheduleRule ScheduleRule::PyScheduleRule( PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyScheduleRuleNode::FApply f_apply, // diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index a6a1100cebe60..f4c392ca2f1a1 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -28,11 +28,36 @@ MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) data_ = std::move(n); } +void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; + f_initialize_with_tune_context(context); +} + void PySearchStrategyNode::PreTuning(const Array& design_spaces, const Optional& database, const Optional& cost_model) { ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; - this->f_pre_tuning(design_spaces, database, cost_model); + f_pre_tuning(design_spaces, database, cost_model); +} + +void PySearchStrategyNode::PostTuning() { + ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; + f_post_tuning(); +} + +Optional> PySearchStrategyNode::GenerateMeasureCandidates() { + ICHECK(f_generate_measure_candidates != nullptr) + << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; + return f_generate_measure_candidates(); +} + +void PySearchStrategyNode::NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) { + ICHECK(f_notify_runner_results != nullptr) + << "PySearchStrategy's NotifyRunnerResults method not implemented!"; + f_notify_runner_results(context, measure_candidates, results); } SearchStrategy SearchStrategy::PySearchStrategy( diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 6df8da2f7aa12..5c5ab6ebbae5b 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -21,6 +21,18 @@ namespace tvm { namespace meta_schedule { +void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySpaceGenerator's InitializeWithTuneContext method not implemented!"; + f_initialize_with_tune_context(context); +} + +Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { + ICHECK(f_generate_design_space != nullptr) + << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; + return f_generate_design_space(mod); +} + SpaceGenerator SpaceGenerator::PySpaceGenerator( PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, 
PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) { diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 25867fb4f3bbf..5d41f2edfb26f 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -199,6 +199,43 @@ Array TaskSchedulerNode::JoinRunningTask(int task_id) { return results; } +void PyTaskSchedulerNode::Tune() { + if (f_tune == nullptr) { + TaskSchedulerNode::Tune(); + } else { + f_tune(); + } +} + +void PyTaskSchedulerNode::InitializeTask(int task_id) { + if (f_initialize_task == nullptr) { + TaskSchedulerNode::InitializeTask(task_id); + } else { + f_initialize_task(task_id); + } +} + +void PyTaskSchedulerNode::TouchTask(int task_id) { + if (f_touch_task == nullptr) { + return TaskSchedulerNode::TouchTask(task_id); + } else { + return f_touch_task(task_id); + } +} + +Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { + if (f_join_running_task == nullptr) { + return TaskSchedulerNode::JoinRunningTask(task_id); + } else { + return f_join_running_task(task_id); + } +} + +int PyTaskSchedulerNode::NextTaskId() { + ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; + return f_next_task_id(); +} + TaskScheduler TaskScheduler::PyTaskScheduler( Array tasks, // Builder builder, // From d8f57ed7ff6daf585ca56bc2cf9326eca9e73fca Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 7 Jun 2022 11:54:46 -0700 Subject: [PATCH 060/181] [Relay] IndexedGraph improvements in preparation for Collage (#11481) * [Relay] Odd's 'n ends changes to help Collage. - Complete the implementation of WithFields. (Unfortunately they appear to be without unit tests and I continue this tradition...) - InferTypeExpr for InferTypeLocal but return the expression rather than the type. - Remove python binding of InlineComposites since C++ impl was removed some time ago. - Make IndexedGraph more robust as stand-alone datastructure, and avoid unnecessary copies. This will become a fundamental datastructure in Collage rather than just a helper for DFPatternMatcher. - Extend IndexedGraph with a notion of 'basic block' on every dataflow node. Needed by Collage to avoid impossible partitions. * - Revert non IndexedGraph changes. 
* - Stick to 'Indexed graph' terminology - More tests * - Stick to 'Indexed graph' terminology - More tests * - Remove silly unit test --- src/relay/ir/dataflow_matcher.cc | 90 ++-- src/relay/ir/dataflow_matcher_impl.h | 19 +- src/relay/ir/indexed_graph.cc | 526 ++++++++++++++------ src/relay/ir/indexed_graph.h | 283 +++++++++-- src/relay/op/dyn/tensor/transform.cc | 1 + tests/cpp/relay/ir/indexed_graph_test.cc | 205 ++++++++ tests/python/relay/test_dataflow_pattern.py | 35 +- 7 files changed, 922 insertions(+), 237 deletions(-) create mode 100644 tests/cpp/relay/ir/indexed_graph_test.cc diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 8d7ed163a1975..df896cb690eb2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -36,6 +36,7 @@ namespace relay { // Pattern Matcher bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr); memo_.clear(); matched_nodes_.clear(); return VisitDFPattern(pattern, expr); @@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr if (out) { memo_[pattern].push_back(expr); matched_nodes_.push_back(pattern); + VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr); } else { ClearMap(watermark); } @@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons if (!matches) { return matches; } - VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr); auto attributes = attr_pattern->attrs.as()->dict; if (const auto* op_node = expr.as()) { Op op = GetRef(op_node); @@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. 
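For reference, after this change the matcher no longer builds its own graph: callers construct an IndexedGraph once and lend it to the matcher, as the updated MatchPattern helper below shows. A minimal usage sketch, assuming only the CreateIndexedGraph and DFPatternMatcher declarations introduced by this patch (the graph must outlive the matcher, which keeps a raw pointer to it):

    // Build the dataflow graph for the expression once, then reuse it for matching.
    std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
    DFPatternMatcher matcher(expr_graph.get());  // borrows expr_graph, does not own it
    bool matched = matcher.Match(pattern, expr);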
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as(); - for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { - if (!(call_node && node->ref_ == call_node->op)) { + auto index_node = expr_to_node(expr); + for (auto node : index_node->inputs_) { + if (!(call_node && node->ref() == call_node->op)) { memoize_ = true; - if (VisitDFPattern(op->parent, node->ref_)) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { memoize_ = false; - if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { + if (!VisitDFPattern(op->path, node->ref())) { + return false; + } + if (!MatchesPath(op, node->ref())) { return false; } } @@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e // Iteratively ensure that the parent is dominated somewhere by the child or the path bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { std::stack stack; - std::unordered_set visited; + std::unordered_set visited; stack.push(expr); while (!stack.empty()) { Expr current = stack.top(); stack.pop(); - for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { - if (visited.count(node->ref_) == 0) { - if (VisitDFPattern(op->parent, node->ref_)) { + for (auto node : expr_to_node(current)->dominator_children_) { + if (visited.count(node->node_ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { - stack.push(node->ref_); + stack.push(node->ref()); } - visited.insert(node->ref_); + visited.insert(node->node_ref_); } } } @@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } bool MatchPattern(DFPattern pattern, Expr expr) { - return DFPatternMatcher(expr).Match(pattern, expr); + std::unique_ptr> expr_graph = CreateIndexedGraph(expr); + return DFPatternMatcher(expr_graph.get()).Match(pattern, expr); } TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); @@ -575,7 +581,8 @@ const std::unordered_map& PatternGrouper::GroupMatch pattern_ = pattern; pattern_graph_ = CreateIndexedGraph(pattern_); - auto matcher = DFPatternMatcher(pre); + std::unique_ptr> expr_graph = CreateIndexedGraph(pre); + DFPatternMatcher matcher(expr_graph.get()); matcher_ = &matcher; this->VisitExprs(); return this->groups_; @@ -583,9 +590,9 @@ const std::unordered_map& PatternGrouper::GroupMatch void PatternGrouper::VisitExprs() { std::unordered_set pre_partitioned; - for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + for (PostDfsIndex i = matcher_->size(); i != 0; --i) { + PostDfsIndex index = i - 1; + const auto current = matcher_->index_to_node(index)->ref(); if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped if (auto op = current.as()) { if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { @@ -607,9 +614,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto node_map = matcher_->GetMemo(); // Get fuzzy patterns std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { + if (auto op = 
node->ref().as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { fuzzy_matches.insert(match); @@ -617,12 +625,13 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { - auto matches = node_map[node->ref_]; + if (node->ref().as()) { + auto matches = node_map[node->ref()]; for (auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + auto sub_graph = CreateIndexedGraph(match.as()->body); + for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) { + auto sub_node = sub_graph->index_to_node(sub_index); + fuzzy_matches.insert(sub_node->ref()); } } } @@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { std::unordered_map inputs; Array params; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); auto make_input = [&](const Expr& input) { if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { + input.as() == nullptr && !EmbedConst(input, node->ref())) { inputs[input] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), NullValue()); @@ -648,11 +658,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { var_number++; } }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); + auto tuple = node->ref().as(); + auto call = node->ref().as(); if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->fields) { make_input(input); @@ -660,8 +670,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->args) { make_input(input); @@ -669,8 +679,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { make_input(match); } @@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) { return; } else if (kv.second != body) { // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); + auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid + if (memo.count(output->ref()) == 0) { + // TODO(mbs): This condition used to also include the following test, which since + // the dominators relation is used back-to-front was always vacuously true. So the + // code is just rejecting the match if a strictly internal node happened to connect + // to an outside node. 
+ ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); + // Exit because nodes in this pattern's body are used outside the pattern, fusing it + // would be invalid return; } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index d993d4720e4ed..f04190f72e40b 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -27,7 +27,9 @@ #include #include #include +#include +#include #include #include #include @@ -39,10 +41,20 @@ namespace relay { class DFPatternMatcher : public DFPatternFunctor { public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + explicit DFPatternMatcher(const IndexedGraph* expr_graph) : expr_graph_(expr_graph) {} bool Match(const DFPattern& pattern, const Expr& expr); Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; + + const IndexedGraph::Node* expr_to_node(const Expr& expr) const { + return expr_graph_->item_to_node(expr); + } + const IndexedGraph::Node* index_to_node(size_t index) const { + return expr_graph_->index_to_node(index); + } + size_t size() const { return expr_graph_->size(); } + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& memo() const { + return memo_; + } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor* expr_graph_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; std::vector matched_nodes_; bool memoize_ = true; @@ -131,7 +144,7 @@ class PatternGrouper { std::unordered_map groups_; std::unordered_map gid_assignments_; DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; + std::unique_ptr> pattern_graph_; int gid_ = 0; int graph_number_ = 0; }; diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4efe57b491db0..f39ff4850eae1 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -19,195 +19,393 @@ /*! * \file src/relay/ir/indexed_graph.cc - * \brief Utilties for Creating Indexed Graphs. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. */ #include "indexed_graph.h" #include #include #include -#include +#include + +#include namespace tvm { namespace relay { -// IndexedGraph +std::string RefToSummary(const Expr& expr) { + class Visitor : public ExprFunctor { + std::string VisitExpr_(const VarNode* op) final { return "%" + op->name_hint(); } + std::string VisitExpr_(const GlobalVarNode* op) final { return "@" + op->name_hint; } + std::string VisitExpr_(const ConstantNode* op) final { return "const"; } + std::string VisitExpr_(const TupleNode* op) final { + return "tuple(" + std::to_string(op->fields.size()) + ")"; + } + std::string VisitExpr_(const FunctionNode* op) final { return "fn"; } + std::string VisitExpr_(const CallNode* op) final { + return VisitExpr(op->op) + "(" + std::to_string(op->args.size()) + ")"; + } + std::string VisitExpr_(const LetNode* op) final { return "let"; } + std::string VisitExpr_(const IfNode* op) final { return "if"; } + std::string VisitExpr_(const OpNode* op) final { return op->name; } + std::string VisitExpr_(const TupleGetItemNode* op) final { + return "." 
+ std::to_string(op->index); + } + std::string VisitExpr_(const RefCreateNode* op) final { return "ref_create"; } + std::string VisitExpr_(const RefReadNode* op) final { return "ref_read"; } + std::string VisitExpr_(const RefWriteNode* op) final { return "ref_write"; } + std::string VisitExpr_(const ConstructorNode* op) final { return "ctor"; } + std::string VisitExpr_(const MatchNode* op) final { return "match"; } + }; + return Visitor().VisitExpr(expr); +} + +std::string RefToSummary(const DFPattern& pattern) { + // TODO(mbs): Implement as debugging requires. + return ""; +} -IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Topological order */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr) { + /*! + * \brief Adds indexed graph nodes in post-dfs order, and discovers which let-bound vars are to + * recursive functions. + */ class Creator : public MixedModeVisitor { public: - IndexedGraph CreateGraph(const Expr& expr) { + std::pair>, + std::unique_ptr>> + CreateGraph(const Expr& expr) { VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; - return std::move(graph_); + // Last visited node is implicitly used 'externally'. + graph_->item_to_node(expr)->is_external_ = true; + return {std::move(graph_), std::move(rec_calls_)}; } protected: using MixedModeVisitor::VisitExpr_; + // By the default the MixedModeVisitor will place + // - callee and arguments before a call + // - tuple fields before a tuple + // - tuple before a tuple projection void VisitLeaf(const Expr& expr) override { + if (const auto* var_node = expr.as()) { + if (var_node == current_let_bound_var_) { + // Don't visit occurrences of let-rec bound vars in the recursive function body. + // Instead, wait for them to be visited at call sites outside of the function. + VLOG(1) << "Ignore let-rec var '" << var_node->name_hint() << "'"; + return; + } + } + MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(expr); + + if (const auto* call_node = expr.as()) { + if (const auto* var_node = call_node->op.as()) { + if (var_node == current_let_bound_var_) { + // Remember this is a recursive call to the let-rec bound function. + // The Annotator functor below will not record any dependency from the let-rec bound + // var to the expression so that the indexed graph is always a DAG. + VLOG(1) << "Remembering recursive call to '" << var_node->name_hint() << "'"; + rec_calls_->emplace(call_node); + } + } + } } - void VisitExpr_(const LetNode* let) override { + void VisitExpr_(const LetNode* let_node) override { auto pre_visit = [&](const LetNode* op) { - this->VisitSpan(op->span); - this->VisitExpr(op->value); - this->VisitExpr(op->var); + // Let-bound values come before their let-bound variable. + const VarNode* prev_let_bound_var = current_let_bound_var_; + current_let_bound_var_ = op->var.get(); + VisitExpr(op->value); + current_let_bound_var_ = prev_let_bound_var; + VisitExpr(op->var); }; auto post_visit = [&](const LetNode* op) { - this->VisitExpr(op->body); - if (let != op) { - Expr expr = GetRef(op); + VisitExpr(op->body); + if (let_node != op) { + // Replicate VisitLeaf, which we are effectively bypassing. 
visit_counter_[op]++; - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(GetRef(op)); } }; - ExpandANormalForm(let, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); } - IndexedGraph graph_; - size_t index_ = 0; + class PatternCreator : public PatternVisitor { + public: + explicit PatternCreator(Creator* creator) : creator_(creator) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + creator_->VisitLeaf(pattern_var_node->var); + } + + Creator* creator_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Matched data comes before match-bound vars then match rhs, in match order. + VisitExpr(match_node->data); + for (const Clause& c : match_node->clauses) { + PatternCreator pattern_creator(this); + pattern_creator.VisitPattern(c->lhs); + VisitExpr(c->rhs); + } + } + + /*! \brief Graph we are accumulated nodes into. */ + std::unique_ptr> graph_ = std::make_unique>(); + /*! \brief Variable the currently visited expression is to be let-bound to, if any. */ + const VarNode* current_let_bound_var_ = nullptr; + /*! \brief Accumulated calls to recursive functions. */ + std::unique_ptr> rec_calls_ = + std::make_unique>(); }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree - * analysis. + + /*! + * \brief Fills in the inputs and outputs for all nodes, then does dominator analysis. * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined - * topological order instead of recursing. + * Thought we use the ExprFunctor to visit nodes, we never recurse and instead just inspect + * each sub-expression's immediate sub-sub-expressions to accumulate inputs and outputs. */ - class Annotator : public ExprFunctor { + class Annotator : public ExprFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + explicit Annotator(std::pair>, + std::unique_ptr>> + args) + : graph_(std::move(args.first)), rec_calls_(std::move(args.second)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - ExprFunctor::VisitExpr(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitExpr(graph_->index_to_node(index)->ref()); } // do the dominator analysis - graph_.PostDom(); + graph_->PostDom(); return std::move(graph_); } - /*! Default visitation pushes the parent to the child's outputs and the child to the parent's - * inputs*/ - void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } + /*! + * \brief Add \p parent as a possible output of the node corresponding to \p expr. 
+ */ + void AddOutput(const Expr& expr, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(expr); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } protected: - IndexedGraph graph_; - void VisitExpr_(const VarNode* op, NodePtr parent) override { - if (op->type_annotation.defined()) { - this->VisitType(op->type_annotation); - } - } + void VisitExpr_(const VarNode* var_node) override {} - void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + void VisitExpr_(const GlobalVarNode* global_var_node) override {} - void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + void VisitExpr_(const ConstantNode* constant_node) override {} - void VisitExpr_(const TupleNode* op, NodePtr parent) override { - for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const TupleNode* tuple_node) override { + auto node = graph_->item_to_node(GetRef(tuple_node)); + for (auto field : tuple_node->fields) { + AddOutput(field, node); } } - void VisitExpr_(const FunctionNode* op, NodePtr parent) override { - for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + // Nothing to do for parameters -- each use of a parameter will contribute to its outputs. + AddOutput(function_node->body, node); + } + + void VisitExpr_(const CallNode* call_node) override { + auto node = graph_->item_to_node(GetRef(call_node)); + if (rec_calls_->count(call_node)) { + // We want the indexed graph to be a DAG, so don't consider a call to a let-rec bound + // function from inside the function to depend on the let-rec bound var. + VLOG(1) << "Ignoring op in call " << RefToSummary(GetRef(call_node)); + } else { + AddOutput(call_node->op, node); + } + for (auto arg : call_node->args) { + AddOutput(arg, node); } + } + + void VisitExpr_(const LetNode* let_node) override { + auto node = graph_->item_to_node(GetRef(let_node)); + auto let_var_node = graph_->item_to_node(let_node->var); + AddOutput(let_node->value, let_var_node); + // Nothing to do for the let-bound variable -- each use of that variable in the let-body + // will contribute to its outputs. 
+ AddOutput(let_node->body, node); + } - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + auto node = graph_->item_to_node(GetRef(if_node)); + AddOutput(if_node->cond, node); + AddOutput(if_node->true_branch, node); + AddOutput(if_node->false_branch, node); } - void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const OpNode* op_node) override {} - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override { + auto node = graph_->item_to_node(GetRef(tuple_get_item_node)); + AddOutput(tuple_get_item_node->tuple, node); + } + + void VisitExpr_(const RefCreateNode* ref_create_node) override { + auto node = graph_->item_to_node(GetRef(ref_create_node)); + AddOutput(ref_create_node->value, node); + } + + void VisitExpr_(const RefReadNode* ref_read_node) override { + auto node = graph_->item_to_node(GetRef(ref_read_node)); + AddOutput(ref_read_node->ref, node); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) override { + auto node = graph_->item_to_node(GetRef(ref_write_node)); + AddOutput(ref_write_node->ref, node); + AddOutput(ref_write_node->value, node); + } + + void VisitExpr_(const ConstructorNode* constructor_node) override {} + + class PatternAnnotator : public PatternVisitor { + public: + PatternAnnotator(Annotator* annotator, const ExprNode* adt_node) + : annotator_(annotator), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto node = annotator_->graph_->item_to_node(pattern_var_node->var); + annotator_->AddOutput(GetRef(adt_node_), node); } - for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + Annotator* annotator_; + const ExprNode* adt_node_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Data flows from the match data to pattern vars into match arms and out into overall + // match. + auto node = graph_->item_to_node(GetRef(match_node)); + for (const Clause& c : match_node->clauses) { + PatternAnnotator pattern_annotator(this, match_node->data.get()); + pattern_annotator.VisitPattern(c->lhs); + AddOutput(c->rhs, node); } } - void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } + std::unique_ptr> graph_; + /*! \brief Accumulated calls to recursive functions. */ + std::unique_ptr> rec_calls_; + }; + + /*! \brief Fills in the basic blocks for all nodes. 
*/ + class Blocker : public MixedModeVisitor { + public: + explicit Blocker(std::unique_ptr> graph) : graph_(std::move(graph)) {} - void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + std::unique_ptr> Scope(const Expr& expr) { + VisitExpr(expr); + return std::move(graph_); } - void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + private: + using MixedModeVisitor::VisitExpr_; - void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + SetScope(expr); } - void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + basic_block_stack_.push_back(node); + ExprVisitor::VisitExpr_(function_node); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + VisitExpr(if_node->cond); + auto node = graph_->item_to_node(GetRef(if_node)); + basic_block_stack_.push_back(node); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const LetNode* let_node) override { + auto pre_visit = [&](const LetNode* op) { + VisitExpr(op->value); + VisitExpr(op->var); + }; + auto post_visit = [&](const LetNode* op) { + VisitExpr(op->body); + if (let_node != op) { + visit_counter_[op]++; + SetScope(GetRef(op)); + } + }; + ExpandANormalForm(let_node, pre_visit, post_visit); } - void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { - for (const Type& t : op->inputs) { - this->VisitType(t); + class PatternBlocker : public PatternVisitor { + public: + explicit PatternBlocker(Blocker* scoper) : scoper_(scoper) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + scoper_->SetScope(pattern_var_node->var); } - this->VisitType(op->belong_to); - } - void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); - for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); + Blocker* scoper_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + VisitExpr(match_node->data); + auto node = graph_->item_to_node(GetRef(match_node)); + basic_block_stack_.push_back(node); + for (const Clause& c : match_node->clauses) { + PatternBlocker pattern_scoper(this); + pattern_scoper.VisitPattern(c->lhs); + VisitExpr(c->rhs); } + basic_block_stack_.pop_back(); } - void VisitClause(const Clause& op, NodePtr parent) { - this->VisitPattern(op->lhs); - this->VisitExpr(op->rhs, parent); + void SetScope(const Expr& expr) { + auto node = graph_->item_to_node(expr); + if (!basic_block_stack_.empty()) { + node->basic_block_ = basic_block_stack_.back(); + } } - void VisitPattern(const Pattern& p) { return; } - - void VisitType(const 
Type& t) { return; } + std::unique_ptr> graph_; + std::vector::Node*> basic_block_stack_; }; - return Annotator(Creator().CreateGraph(expr)).Annotate(); + + VLOG(1) << "CreateIndexedGraph:" << std::endl << PrettyPrint(expr); + std::unique_ptr> graph = + Blocker(Annotator(Creator().CreateGraph(expr)).Annotate()).Scope(expr); + VLOG(1) << "graph:" << std::endl << graph->ToString(); +#if TVM_LOG_DEBUG + graph->CheckValid(); +#endif + return graph; } -IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern) { + /*! \brief Creates an IndexedGraph and determines topological order */ class Creator : public DFPatternVisitor { public: - IndexedGraph CreateGraph(const DFPattern& pattern) { + std::unique_ptr> CreateGraph(const DFPattern& pattern) { + graph_ = std::make_unique>(); VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; + graph_->item_to_node(pattern)->is_external_ = true; return std::move(graph_); } @@ -215,121 +413,135 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern(const DFPattern& pattern) override { if (this->visited_.count(pattern.get()) == 0) { DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(pattern); } } - IndexedGraph graph_; - size_t index_ = 0; + + std::unique_ptr> graph_; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree * analysis. * * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined * topological order instead of recursing. */ - class Annotator : public DFPatternFunctor { + class Annotator : public DFPatternFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + Annotator(std::unique_ptr> graph) : graph_(std::move(graph)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitDFPattern(graph_->index_to_node(index)->ref()); } - graph_.PostDom(); // do the dominator analysis + graph_->PostDom(); return std::move(graph_); } /*! 
Default visitation pushes the parent to the child's outputs */ - void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; + void AddOutput(const DFPattern& pattern, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(pattern); if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } } protected: - IndexedGraph graph_; - void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AltPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->left, node); + AddOutput(op->right, node); } - void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AttrPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const CallPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->op, node); if (op->args.defined()) { for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + AddOutput(arg, node); } } } - void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ConstantPatternNode* op) override {} - void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DataTypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DominatorPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->parent, node); + AddOutput(op->path, node); + AddOutput(op->child, node); } - void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ExprPatternNode* op) override {} - void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const FunctionPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->params.defined()) { for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + AddOutput(param, node); } } - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + AddOutput(op->body, node); } - void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const ShapePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TupleGetItemPatternNode* op) override { + auto node = 
graph_->item_to_node(GetRef(op)); + AddOutput(op->tuple, node); } - void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const TuplePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->fields.defined()) { for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + AddOutput(field, node); } } } - void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const IfPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->cond, node); + AddOutput(op->true_branch, node); + AddOutput(op->false_branch, node); } - void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const LetPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->var, node); + AddOutput(op->value, node); + AddOutput(op->body, node); } - void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const VarPatternNode* op) override {} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const WildcardPatternNode* op) override {} + + std::unique_ptr> graph_; }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); } diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d073bcaeea5c9..c1ce53f40da3d 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -19,7 +19,12 @@ /*! * \file src/relay/ir/indexed_graph.h - * \brief A pattern matcher for matching dataflow properties. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. Each 'indexed graph' node is 1:1 with an expression/pattern 'node', hence the + * term 'IndexedGraph'. Dataflow is captured in a generic representation which is convenient + * for analysis, particularly pattern matching and partitioning. + * + * TODO(mbs): Copied from fuse_ops.cc, consider refactoring to share implementation. */ #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ @@ -28,6 +33,7 @@ #include #include +#include #include #include #include @@ -36,47 +42,108 @@ namespace tvm { namespace relay { +/*! \brief The index of a node in the post-dfs traversal of overall expression. */ +using PostDfsIndex = size_t; + +/*! + * \brief Returns a brief summary of the 'reference' expression or pattern. Only used by + * IndexedGraph::ToString() for debugging. + */ +std::string RefToSummary(const Expr& expr); +std::string RefToSummary(const DFPattern& pattern); + /*! 
- * \brief A Wrapper around a templated graph type - * Holds a forward-backward indexed representation of the graph and a dominator tree representation - * of the graph + * \brief Represents the implied dataflow of an expression or (dataflow) pattern as a DAG who's + * nodes are 1:1 with those in the underlying expression/pattern. + * + * Each indexed graph node captures: + * - Dataflow inputs. + * - Dataflow outputs (or a flag indicating the node is an implied output). + * - Dominator parent (ie closest node at which all outputs of the current node re-combine). + * - Dominator children (inverse of above). + * - Basic block (ie node representing the body of a function, arm of an if, etc). * - * This class is templated and the implementaiton is in the header file so we can analyze both - * DFPattern and Expr with the same infrastructure. + * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. * - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities below. */ template class IndexedGraph { public: - /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + using TNode = typename T::ContainerType; + + /*! \brief A Node in the graph. */ struct Node { /*! \brief Node Constructor - * \param ref The input graph node - * \param index The index of the node in toplogical order + * \param ref The expression or dataflow pattern node this indexed graph node is augmenting. + * \param index The index of this node in the topological order */ - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + Node(const TNode* ref, PostDfsIndex index) : node_ref_(ref), index_(index) {} + + /*! \brief The underlying expression or pattern node. */ + const TNode* node_ref_; - /*! \brief The input node */ - const T ref_; - /*! \brief The topological order index */ - const size_t index_; + T ref() const { + ICHECK(node_ref_ != nullptr); + return GetRef(node_ref_); + } + + /*! + * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then + * left does not flow into right. If left.index_ = right.index_ then left and right are + * the same node. + */ + const PostDfsIndex index_; - /*! \brief A boolean to determine if this node is external to the graph */ + /*! \brief If true this node has implicit outputs, for example as the result of a function. */ bool is_external_ = false; - /*! \brief The forward inputs of the node */ + /*! \brief Immediate dataflow inputs to this node. */ std::vector inputs_; - /*! \brief The forward outputs/users of the node */ + /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */ std::vector outputs_; - /*! \brief The depth of the node in the dominator tree */ + /*! + * \brief The node representing the 'basic block' containing this node: + * - Function bodies start a new basic block for their bodies. + * - The true and false branches of an if start their own blocks. + * - The arms of a match each have their own blocks. + */ + Node* basic_block_ = nullptr; + + /*! \brief The depth of this node in the dominator tree */ size_t depth_ = 0; - /*! \brief The dominator parent/final user of the outputs of this node */ - Node* dominator_parent_; - /*! \brief The nodes this node dominates */ + /*! + * \brief The dominator parent of this node. 
This is the node N with least index such that + * all possible dataflows from this node pass through N. + */ + Node* dominator_parent_ = nullptr; + /*! \brief The nodes this node dominates. */ std::vector dominator_children_; - bool Dominates(const Node* other) { + /*! + * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be + * reached by following output paths. + */ + void AccumulateDownstreamNodes(std::unordered_set* nodes) const { + std::stack stack; + stack.push(this); + while (!stack.empty()) { + const Node* current = stack.top(); + stack.pop(); + for (auto node : current->outputs_) { + if (nodes->count(node) == 0) { + stack.push(node); + nodes->insert(node); + } + } + } + } + + /*! + * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p + * other pass through \p this. + */ + bool Dominates(const Node* other) const { std::stack stack; std::unordered_set visited; stack.push(this); @@ -97,10 +164,125 @@ class IndexedGraph { return false; } }; + + PostDfsIndex size() const { return topological_order_.size(); } + + Node* item_to_node(const T& item) { return item_to_node(item.get()); } + const Node* item_to_node(const T& item) const { return item_to_node(item.get()); } + + Node* item_to_node(const TNode* item) { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + const Node* item_to_node(const TNode* item) const { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + Node* index_to_node(PostDfsIndex index) { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + const Node* index_to_node(PostDfsIndex index) const { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + /*! + * \brief (For debugging only) Returns description of indexed graph with hints as to the + * sub-expressions or sub-patterns corresponding to each indexed graph node. + */ + std::string ToString() const { + std::ostringstream os; + os << "IndexedGraph(size = " << topological_order_.size() << ") {" << std::endl; + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + ICHECK_EQ(index, node->index_); + os << " " << index << " (" << RefToSummary(node->ref()) << "): inputs=["; + for (const auto* sub_node : node->inputs_) { + os << sub_node->index_ << ","; + } + os << "], outputs=["; + for (const auto* sub_node : node->outputs_) { + os << sub_node->index_ << ","; + } + os << "]"; + if (node->is_external_) { + os << ", external"; + } + if (node->basic_block_) { + os << ", basic_block=" << node->basic_block_->index_; + } + if (node->depth_ > 0) { + os << ", depth=" << node->depth_; + } + if (node->dominator_parent_) { + os << ", dom_parent=" << node->dominator_parent_->index_; + } + os << ", dom_children=["; + for (const auto* sub_node : node->dominator_children_) { + os << sub_node->index_ << ","; + } + os << "]" << std::endl; + } + os << "}"; + return os.str(); + } + + /*! + * Check-fails if the graph is ill-formed. For debugging only. + */ + void CheckValid() const { + ICHECK_GT(topological_order_.size(), 0); + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + // We have a node. 
+ ICHECK(node); + // Bijections with post-dfs indexes and expressions/patterns are correct. + ICHECK_EQ(node->index_, index); + ICHECK(node->node_ref_); + auto itr = node_map_.find(node->node_ref_); + ICHECK(itr != node_map_.end()); + ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString(); + // Inputs come before. + for (size_t i = 0; i < node->inputs_.size(); ++i) { + const Node* input = node->inputs_[i]; + ICHECK(input); + ICHECK_LT(input->index_, index); + ICHECK(std::find(input->outputs_.begin(), input->outputs_.end(), node) != + input->outputs_.end()); + } + // Outputs come after. + for (size_t i = 0; i < node->outputs_.size(); ++i) { + const Node* output = node->outputs_[i]; + ICHECK(output); + ICHECK_GT(output->index_, index); + ICHECK(std::find(output->inputs_.begin(), output->inputs_.end(), node) != + output->inputs_.end()); + } + ICHECK_GT(node->depth_, 0); + // Dominator children come before. + for (size_t i = 0; i < node->dominator_children_.size(); ++i) { + const Node* child = node->dominator_children_[i]; + ICHECK(child); + ICHECK_LT(child->index_, index); + } + if (node->dominator_parent_) { + // Dominator comes after. + ICHECK_GT(node->dominator_parent_->index_, index); + } + } + } + + private: /*! \brief Construct the domination tree inside IndexedGraph */ void PostDom() { - for (size_t i = topological_order_.size(); i != 0; --i) { - size_t index = i - 1; + for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { + PostDfsIndex index = i - 1; auto* current = topological_order_[index].get(); if (current->is_external_) { current->depth_ = 1; @@ -109,16 +291,13 @@ class IndexedGraph { auto parent = LeastCommonAncestor(current->outputs_); current->depth_ = parent ? parent->depth_ + 1 : 1; current->dominator_parent_ = parent; - parent->dominator_children_.push_back(current); + if (parent) { + parent->dominator_children_.push_back(current); + } } } } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; - /*! \brief Topological IndexedGraph Nodes */ - std::vector> topological_order_; - protected: /*! \brief Find the least common ancestor of all outputs of a node */ Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { @@ -136,9 +315,11 @@ class IndexedGraph { if (lhs == nullptr || rhs == nullptr) { return nullptr; } + PostDfsIndex lhs_index = lhs->index_; + PostDfsIndex rhs_index = rhs->index_; while (lhs != rhs) { - ICHECK(lhs); - ICHECK(rhs); + ICHECK(lhs && rhs) << "LCA(" << lhs_index << ", " << rhs_index << ") on graph:" << std::endl + << ToString(); if (lhs->depth_ < rhs->depth_) { rhs = rhs->dominator_parent_; } else if (lhs->depth_ > rhs->depth_) { @@ -150,13 +331,41 @@ class IndexedGraph { } return lhs; } + + /*! + * \brief Appends a node corresponding to \p ref, and maintains the sub-expression/sub-pattern to + * node bijection. The insertion index will be the node's PostDfsIndex. All other node properties + * are accumulated in-place. + */ + void AddNode(const T& ref) { + PostDfsIndex index = topological_order_.size(); + auto node = std::make_unique(ref.get(), index); + node_map_[ref.get()] = node.get(); + topological_order_.emplace_back(std::move(node)); + } + + /*! + * \brief Map from underlying sub-expression or sub-pattern nodes to their indexed graph nodes. + */ + std::unordered_map node_map_; + /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. 
*/ + std::vector> topological_order_; + + friend std::unique_ptr> CreateIndexedGraph(const Expr& expr); + friend std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); }; -/*! \brief Create an Indexed Graph based on an Expr */ -IndexedGraph CreateIndexedGraph(const Expr& expr); -/*! \brief Create an Indexed Graph based on an DFPattern */ -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); +/*! \brief Returns an Indexed Graph for \p expr, which much outlive the result. */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr); + +/*! + * \brief Returns an Indexed Graph for \p pattern, which must outlive the result. + * The dataflow for a pattern mimics the dataflow for the expression which would match + * that pattern. + */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm + #endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index f7045305e90d3..d5cc6608662b2 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -258,6 +258,7 @@ RELAY_REGISTER_OP("dyn.broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. )code" TVM_ADD_FILELINE) .set_num_inputs(2) + .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc new file mode 100644 index 0000000000000..17ec682616843 --- /dev/null +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/relay/ir/indexed_graph.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace { + +// A module stolen from onnx/test_forward.py::test_loop which combines functions, recursion, +// control flow, tuples as well as the usual operator calls. +// We include the known post-dfs indexes in comments to help write the tests. 
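As a reading aid for the post-dfs index annotations in the module below: indexes are assigned in post-dfs order, so every sub-expression is numbered before any expression that consumes it. A minimal sketch of checking that property through the new API (variable names are illustrative; the test body further down contains the real assertions):

    auto graph = CreateIndexedGraph(main);  // 'main' as parsed from the model below
    // Index 40 is the recursive call to %while_loop; all of its dataflow inputs were
    // visited earlier in the traversal and therefore carry smaller indexes.
    const auto* rec_call = graph->index_to_node(40);
    for (const auto* input : rec_call->inputs_) {
      ICHECK_LT(input->index_, rec_call->index_);
    }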
+IRModule TestRecursiveIRModule() { + Device device = {kDLCPU, 0}; + Constant const0(runtime::NDArray::Empty(ShapeTuple({1}), DataType::Int(64), device)); + Constant const1(runtime::NDArray::Empty(ShapeTuple({0, 1}), DataType::Float(32), device)); + Map> metadata; + metadata.Set("relay.Constant", Array({const0, const1})); + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%trip_count: int64, // 0 + %cond: bool, // 1 + %y: Tensor[(1), float32]) // 2 + -> (Tensor[(1), float32], Tensor[(?, ?), float32]) { + %17 = ( + let %while_loop = fn (%iter_count: int64, // 3 + %max_count: int64, // 4 + %cond_in: bool, // 5 + %y_in: Tensor[(1), float32], // 6 + %scan_out: Tensor[(?, ?), float32]) // 7 + -> (int64, int64, bool, Tensor[(1), float32], Tensor[(?, ?), float32]) { + %0 = equal(%cond_in, True); // 11 + %1 = less(%iter_count, %max_count); // 13 + %2 = logical_and(%0, %1); // 14 + if (%2) { + %3 = cast(%iter_count, dtype="float32"); // 20 + %4 = add(%y_in, %3); // 21 + %5 = less(%4, 5f); // 23 + %6 = squeeze(%5); // 24 + %7 = reshape(%iter_count, newshape=[1]); // 29 + %8 = (%7, meta[relay.Constant][0]); // 31 + %9 = concatenate(%8); // 32 + %10 = copy(%4); // 36 + %11 = dyn.broadcast_to(%scan_out, %9, shape=None); // 33 + %12 = expand_dims(%10, axis=0); // 37 + %13 = (%11, %12); // 38 + %14 = add(%iter_count, 1i64); // 17 + %15 = cast(%6, dtype="bool"); // 25 + %16 = concatenate(%13); // 39 + %while_loop(%14, %max_count, %15, %4, %16) // 40 + } else { + (%iter_count, %max_count, %cond_in, %y_in, %scan_out) // 41 + } // 42 + }; // 43 + %while_loop // 44 + ); // 45 + %18 = %17(0i64, %trip_count, %cond, %y, meta[relay.Constant][1]); // 48 + %19 = %18.3; // 49 + %20 = %18.4; // 50 + (%19, %20) // 51 + } // 52 + )"; + return parser::ParseModule("string", kModel, /*init_module=*/{}, metadata); +} + +TEST(IndexedGraph, RecursiveExprRegression) { + IRModule ir_mod = TestRecursiveIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + // Dataflow node properties for %4 + auto node = graph->index_to_node(21); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* op_node = call_node->op.as(); + ASSERT_NE(op_node, nullptr); + ASSERT_EQ(op_node->name, "add"); + + // 3 inputs (the op itself is an input) + ASSERT_EQ(node->inputs_.size(), 3); + ASSERT_EQ(node->inputs_[0]->index_, 15); // the add op + ASSERT_EQ(node->inputs_[1]->index_, 6); // %y_in + ASSERT_EQ(node->inputs_[2]->index_, 20); // %3 + + // 3 outputs + ASSERT_EQ(node->outputs_.size(), 3); + ASSERT_EQ(node->outputs_[0]->index_, 23); // %5 + ASSERT_EQ(node->outputs_[1]->index_, 36); // %10 + ASSERT_EQ(node->outputs_[2]->index_, 40); // recursive %while_loop call + + // In the 'if' basic block + ASSERT_EQ(node->basic_block_->index_, 42); + + // Dominator 'parent' is recursive call + ASSERT_EQ(node->dominator_parent_->index_, 40); + + // One dominator child from %3 + ASSERT_EQ(node->dominator_children_.size(), 1); + ASSERT_EQ(node->dominator_children_[0]->index_, 20); + } + + { + // The recursive call to %while_loop does not depend on %while_loop + auto node = graph->index_to_node(40); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* var_node = call_node->op.as(); + ASSERT_NE(var_node, nullptr); + ASSERT_EQ(var_node->name_hint(), "while_loop"); + + ASSERT_EQ(node->inputs_.size(), 5); + ASSERT_EQ(node->inputs_[0]->index_, 17); // %14 + ASSERT_EQ(node->inputs_[1]->index_, 4); // 
%max_count + ASSERT_EQ(node->inputs_[2]->index_, 25); // %15 + ASSERT_EQ(node->inputs_[3]->index_, 21); // %4 + ASSERT_EQ(node->inputs_[4]->index_, 39); // %16 + } + + { + // Downstream nodes of %18 + auto node = graph->index_to_node(48); + std::unordered_set::Node*> downstreams; + node->AccumulateDownstreamNodes(&downstreams); + ASSERT_EQ(downstreams.size(), 4); + for (const auto* downstream : downstreams) { + ASSERT_TRUE(downstream->index_ >= 49 && downstream->index_ <= 52); + } + } + + { + // Dominates relation for %4 + auto upstream = graph->index_to_node(21); + // Path 1: 21->23->24->25->40 + // Path 2: 21->36->37->38->39->40 + // Then 40->43 + auto downstream = graph->index_to_node(43); + ASSERT_TRUE(downstream->Dominates(upstream)); + } +} + +// A module with unused let-bound function. The 'add' operator should have no dominator +// since it is used both in the unused function and in the main body. +IRModule TestUnusedLetBoundIRModule() { + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%x: int64) -> int64 { // 0 + let %f = fn ( // 5 + %y: int64 // 1 + ) { + add(%x, %y) // 3 + }; + if (less(%x, 5i64)) { + add(%x, 3i64) // 10 + } else { + %x + } + } + )"; + return parser::ParseModule("string", kModel); +} + +TEST(IndexedGraph, UnusedLetVars) { + IRModule ir_mod = TestUnusedLetBoundIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + auto node = graph->index_to_node(2); + const auto* op_node = node->ref().as(); + ICHECK(op_node); + ICHECK_EQ(op_node->name, "add"); + ICHECK_EQ(node->outputs_.size(), 2); + ICHECK_EQ(node->outputs_[0]->index_, 3); + ICHECK_EQ(node->outputs_[1]->index_, 10); + ICHECK(node->dominator_parent_ == nullptr); + } +} + +} // namespace +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 74e03f6a97551..f0474c9112736 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=unused-wildcard-import import numpy as np -import pytest import tvm from tvm import relay @@ -601,6 +600,38 @@ def test_match_fake_diamond(): assert not diamond.match(out) +def test_at_most_one_parent(): + # Pattern + P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent' + I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("add")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # n6(P) + # / \ + # n7 \ + # / \ + # n8(P) n10(I) + # \ / + # n9(I) / + # \ / + # n11(C) + + x = relay.var("x") + w = relay.var("w") + n6 = relay.op.nn.conv2d(x, w) # matches P + n7 = relay.op.tanh(n6) # does not match I + n8 = relay.op.nn.conv2d(n7, w) # matches P + n9 = relay.op.nn.relu(n8) # matches I + n10 = relay.op.nn.relu(n6) # matches I + n11 = relay.add(n9, n10) # matches C + + # Does not match: Can't match the parent pattern P at both 8 and 6. + # Note that if we did allow P to be used twice the implementation would + # need to be changed to not 'jump over' n7. 
+ assert not pattern.match(n11) + + def test_match_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) @@ -1760,4 +1791,4 @@ def callback(self, pre, post, node_map): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() From 774ee969fcb19e9d16e74de77c64848fd30e9a52 Mon Sep 17 00:00:00 2001 From: Christian Convey Date: Tue, 7 Jun 2022 18:16:25 -0400 Subject: [PATCH 061/181] [relay] add missing virtual d'tor (#11601) Add a default virtual destructor to `tvm::relay::transforms::GlobalSymbolCache`, so that correct destructors run when destroying subclass instances. --- src/relay/transforms/compiler_function_utils.cc | 2 ++ src/relay/transforms/compiler_function_utils.h | 1 + 2 files changed, 3 insertions(+) diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index b98d089b346a3..f22e9bd80dd07 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -119,6 +119,8 @@ class CallRewriter : public MixedModeMutator { } // namespace +GlobalSymbolCache::~GlobalSymbolCache() = default; + GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(opt_global_symbol.defined()) diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 7b5143444bf8a..e4b1f05211fe1 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -71,6 +71,7 @@ namespace transforms { */ class GlobalSymbolCache { public: + virtual ~GlobalSymbolCache(); virtual GlobalVar GetGlobalSymbol(const Function& function) = 0; }; From d490620085792f802d606209008698e65fb12c0e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 7 Jun 2022 17:16:37 -0500 Subject: [PATCH 062/181] [Hexagon][CI] Re-enable Hexagon tests in CI (#11613) * [Hexagon][CI] Re-enable Hexagon tests in CI These were enabled in https://github.com/apache/tvm/pull/11294, then erroneously disabled in https://github.com/apache/tvm/pull/11313. This applies the same fix as in https://github.com/apache/tvm/pull/11294, checking the `ANDROID_SERIAL_NUMBER` to determine if Hexagon tests can execute at runtime, but using the refactored `pytest.skipif` messages introduced in https://github.com/apache/tvm/pull/11313. * Fixed circular dependency, but feels somewhat ugly --- python/tvm/contrib/hexagon/_ci_env_check.py | 62 +++++++++++++++++++++ python/tvm/contrib/hexagon/pytest_plugin.py | 10 +--- python/tvm/testing/utils.py | 8 +-- 3 files changed, 66 insertions(+), 14 deletions(-) create mode 100644 python/tvm/contrib/hexagon/_ci_env_check.py diff --git a/python/tvm/contrib/hexagon/_ci_env_check.py b/python/tvm/contrib/hexagon/_ci_env_check.py new file mode 100644 index 0000000000000..c1c70750e86ae --- /dev/null +++ b/python/tvm/contrib/hexagon/_ci_env_check.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon environment checks for CI usage
+
+These may be required by either tvm.testing or
+tvm.contrib.hexagon.pytest_plugin, and are separated here to avoid a
+circular dependency.
+"""
+
+import os
+
+import tvm
+
+ANDROID_SERIAL_NUMBER = "ANDROID_SERIAL_NUMBER"
+HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
+
+
+def _compile_time_check():
+    """Return True if compile-time support for Hexagon is present, otherwise
+    error string.
+
+    Designed for use as the ``compile_time_check`` argument to
+    `tvm.testing.Feature`.
+    """
+    if (
+        tvm.testing.utils._cmake_flag_enabled("USE_LLVM")
+        and tvm.target.codegen.llvm_version_major() < 7
+    ):
+        return "Hexagon requires LLVM 7 or later"
+
+    if "HEXAGON_TOOLCHAIN" not in os.environ:
+        return f"Missing environment variable {HEXAGON_TOOLCHAIN}."
+
+    return True
+
+
+def _run_time_check():
+    """Return True if run-time support for Hexagon is present, otherwise
+    error string.
+
+    Designed for use as the ``run_time_check`` argument to
+    `tvm.testing.Feature`.
+    """
+    if ANDROID_SERIAL_NUMBER not in os.environ:
+        return f"Missing environment variable {ANDROID_SERIAL_NUMBER}."
+
+    return True
diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py
index 2c62a0a0b5694..278bd833da954 100644
--- a/python/tvm/contrib/hexagon/pytest_plugin.py
+++ b/python/tvm/contrib/hexagon/pytest_plugin.py
@@ -53,15 +53,7 @@ def _compose(args, decs):
     return decs


-def requires_hexagon_toolchain(*args):
-    _requires_hexagon_toolchain = [
-        pytest.mark.skipif(
-            os.environ.get(HEXAGON_TOOLCHAIN) is None,
-            reason=f"Missing environment variable {HEXAGON_TOOLCHAIN}.",
-        ),
-    ]
-
-    return _compose(args, _requires_hexagon_toolchain)
+requires_hexagon_toolchain = tvm.testing.requires_hexagon(support_required="compile-only")


 @tvm.testing.fixture
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 939786c9294fc..bf3cc94f5ddf7 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -88,6 +88,7 @@ def test_something():

 import tvm._ffi
 from tvm.contrib import nvcc, cudnn
+import tvm.contrib.hexagon._ci_env_check as hexagon
 from tvm.error import TVMError


@@ -937,11 +938,8 @@ def _any_gpu_exists():
     "Hexagon",
     cmake_flag="USE_HEXAGON",
     target_kind_enabled="hexagon",
-    compile_time_check=lambda: (
-        (_cmake_flag_enabled("USE_LLVM") and tvm.target.codegen.llvm_version_major() >= 7)
-        or "Hexagon requires LLVM 7 or later"
-    ),
-    target_kind_hardware="hexagon",
+    compile_time_check=hexagon._compile_time_check,
+    run_time_check=hexagon._run_time_check,
     parent_features="llvm",
 )

From 52d90da1d3bc6b12611b1d30a38c02837fbf8d76 Mon Sep 17 00:00:00 2001
From: "Kathryn (Jinqi) Chen" <65606304+Kathryn-cat@users.noreply.github.com>
Date: Tue, 7 Jun 2022 18:05:14 -0700
Subject: [PATCH 063/181] [MetaSchedule] TuningRecord Optional Arguments
 (#11598)

In some situations, such as before measuring the candidates, the arguments
`run_secs`, `target`, and `args_info` in `TuningRecord` are not required.
Per this request, the new `TuningRecord` API now accepts arguments in the order of `trace, workload, run_secs, target, args_info` with the last three being optional. Note that some tests might fail due to the change of argument order, so they might need to be adjusted accordingly. --- include/tvm/meta_schedule/database.h | 17 +++--- python/tvm/meta_schedule/database/database.py | 26 ++++----- python/tvm/meta_schedule/testing/utils.py | 2 +- src/meta_schedule/database/database.cc | 54 ++++++++++++------- src/meta_schedule/database/json_database.cc | 4 +- .../measure_callback/add_to_database.cc | 2 +- .../unittest/test_meta_schedule_database.py | 26 ++++----- .../test_meta_schedule_integration.py | 2 +- .../unittest/test_meta_schedule_tune_relay.py | 2 +- 9 files changed, 75 insertions(+), 60 deletions(-) diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 1353dec3eda3f..37a315bf744e9 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -103,19 +103,19 @@ class TuningRecordNode : public runtime::Object { public: /*! \brief The trace tuned. */ tir::Trace trace; - /*! \brief The profiling result in seconds. */ - Array run_secs; /*! \brief The workload. */ Workload workload{nullptr}; + /*! \brief The profiling result in seconds. */ + Optional> run_secs; /*! \brief The target for tuning. */ - Target target; + Optional target; /*! \brief The argument information. */ - Array args_info; + Optional> args_info; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("trace", &trace); - v->Visit("run_secs", &run_secs); v->Visit("workload", &workload); + v->Visit("run_secs", &run_secs); v->Visit("target", &target); v->Visit("args_info", &args_info); } @@ -140,13 +140,14 @@ class TuningRecord : public runtime::ObjectRef { /*! \brief Constructor of a tuning record. \param trace The trace of the tuning record. - \param run_secs The running time of the tuning record. \param workload The workload of the tuning record. + \param run_secs The running time of the tuning record. \param target The target of the tuning record. \param args_info The argument information of the tuning record. */ - TVM_DLL explicit TuningRecord(tir::Trace trace, Array run_secs, Workload workload, - Target target, Array args_info); + TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload, + Optional> run_secs, Optional target, + Optional> args_info); /*! * \brief Create a tuning record from a json object. * \param json_obj The json object. diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 314bf434c417f..8e0c805410204 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Tuning record database""" -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional from tvm._ffi import register_object from tvm.ir.module import IRModule @@ -82,35 +82,35 @@ class TuningRecord(Object): ---------- trace : tvm.ir.Trace The trace of the tuning record. - run_secs : List[float] - The run time of the tuning record. workload : Workload The workload of the tuning record. - target : Target + run_secs : Optional[List[float]] + The run time of the tuning record. + target : Optional[Target] The target of the tuning record. 
- args_info : List[ArgInfo] + args_info : Optional[List[ArgInfo]] The argument information of the tuning record. """ trace: Trace - run_secs: List[float] workload: Workload - target: Target - args_info: List[ArgInfo] + run_secs: Optional[List[float]] + target: Optional[Target] + args_info: Optional[List[ArgInfo]] - def __init__( + def __init__( # type: ignore # pylint: disable=too-many-arguments self, trace: Trace, - run_secs: List[float], workload: Workload, - target: Target, - args_info: List[ArgInfo], + run_secs: Optional[List[float]] = None, + target: Optional[Target] = None, + args_info: Optional[List[ArgInfo]] = None, ) -> None: self.__init_handle_by_constructor__( _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member trace, - run_secs, workload, + run_secs, target, args_info, ) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index a832dfc6bcc4a..62950fdd0bb4a 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -155,7 +155,7 @@ def apply_fixed_schedules( if schedule_fn(task, sch): workload = database.commit_workload(mod) - tune_rec = TuningRecord(sch.trace, [0.0], workload, target, []) + tune_rec = TuningRecord(sch.trace, workload, [0.0], target, []) database.commit_tuning_record(tune_rec) return database diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index fc7cc74de5c67..86d999e4fdf59 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -74,48 +74,62 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { /******** TuningRecord ********/ -TuningRecord::TuningRecord(tir::Trace trace, Array run_secs, Workload workload, - Target target, Array args_info) { +TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional> run_secs, + Optional target, Optional> args_info) { ObjectPtr n = make_object(); n->trace = trace; - n->run_secs = run_secs; n->workload = workload; + n->run_secs = run_secs; n->target = target; n->args_info = args_info; this->data_ = n; } ObjectRef TuningRecordNode::AsJSON() const { - Array json_args_info; - json_args_info.reserve(args_info.size()); - for (const ArgInfo& arg_info : args_info) { - json_args_info.push_back(arg_info->AsJSON()); + Optional> json_args_info{nullptr}; + Optional json_target{nullptr}; + if (args_info.defined()) { + Array info; + info.reserve(args_info.value().size()); + for (const ArgInfo& arg_info : args_info.value()) { + info.push_back(arg_info->AsJSON()); + } + json_args_info = info; + } + if (target.defined()) { + json_target = target.value()->Export(); } return Array{trace->AsJSON(false), // run_secs, // - target->Export(), // + json_target, // json_args_info}; } TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { tir::Trace trace{nullptr}; - Array run_secs{nullptr}; - Target target{nullptr}; - Array args_info; + Optional> run_secs{nullptr}; + Optional target{nullptr}; + Optional> args_info{nullptr}; try { const ArrayNode* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 4); // Load json[1] => run_secs - run_secs = Downcast>(json_array->at(1)); + if (json_array->at(1).defined()) { + run_secs = Downcast>(json_array->at(1)); + } // Load json[2] => target - target = Target(Downcast>(json_array->at(2))); + if (json_array->at(2).defined()) { + target = Target(Downcast>(json_array->at(2))); + } // Load json[3] => args_info - { + if (json_array->at(3).defined()) { 
const ArrayNode* json_args_info = json_array->at(3).as(); - args_info.reserve(json_args_info->size()); + Array info; + info.reserve(json_args_info->size()); for (const ObjectRef& json_arg_info : *json_args_info) { - args_info.push_back(ArgInfo::FromJSON(json_arg_info)); + info.push_back(ArgInfo::FromJSON(json_arg_info)); } + args_info = info; } // Load json[0] => trace { @@ -130,7 +144,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj << "\nThe error is: " << e.what(); } - return TuningRecord(trace, run_secs, workload, target, args_info); + return TuningRecord(trace, workload, run_secs, target, args_info); } /******** PyDatabase ********/ @@ -161,9 +175,9 @@ TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON") .set_body_method(&WorkloadNode::AsJSON); TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") - .set_body_typed([](tir::Trace trace, Array run_secs, Workload workload, Target target, - Array args_info) { - return TuningRecord(trace, run_secs, workload, target, args_info); + .set_body_typed([](tir::Trace trace, Workload workload, Optional> run_secs, + Optional target, Optional> args_info) { + return TuningRecord(trace, workload, run_secs, target, args_info); }); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 2e76940feee39..155d223217da9 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -40,8 +40,8 @@ struct SortTuningRecordByMeanRunSecs { } bool operator()(const TuningRecord& a, const TuningRecord& b) const { - double a_time = Mean(a->run_secs); - double b_time = Mean(b->run_secs); + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); return a_time < b_time; } }; diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 0988da0414e2a..27b4e55a7de5b 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -47,8 +47,8 @@ class AddToDatabaseNode : public MeasureCallbackNode { } database->CommitTuningRecord(TuningRecord( /*trace=*/candidate->sch->trace().value(), - /*run_secs=*/run_secs, /*workload=*/workload, + /*run_secs=*/run_secs, /*target=*/target, /*args_info=*/candidate->args_info)); } diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index d494f997c1ce7..1edfbe6c7a782 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -115,8 +115,8 @@ def test_meta_schedule_tuning_record_round_trip(): workload = database.commit_workload(mod) record = TuningRecord( _create_schedule(mod, _schedule_matmul).trace, - [1.5, 2.5, 1.8], workload, + [1.5, 2.5, 1.8], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ) @@ -140,8 +140,8 @@ def test_meta_schedule_database_has_workload(): workload = database.commit_workload(mod) record = TuningRecord( _create_schedule(mod, _schedule_matmul).trace, - [1.5, 2.5, 1.8], workload, + [1.5, 2.5, 1.8], tvm.target.Target("llvm"), 
ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ) @@ -158,8 +158,8 @@ def test_meta_schedule_database_add_entry(): workload = database.commit_workload(mod) record = TuningRecord( _create_schedule(mod, _schedule_matmul).trace, - [1.5, 2.5, 1.8], workload, + [1.5, 2.5, 1.8], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ) @@ -178,8 +178,8 @@ def test_meta_schedule_database_missing(): workload_2 = database.commit_workload(mod_2) record = TuningRecord( _create_schedule(mod, _schedule_matmul).trace, - [1.5, 2.5, 1.8], workload, + [1.5, 2.5, 1.8], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ) @@ -197,43 +197,43 @@ def test_meta_schedule_database_sorting(): records = [ TuningRecord( trace, - [7.0, 8.0, 9.0], token, + [7.0, 8.0, 9.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [1.0, 2.0, 3.0], token, + [1.0, 2.0, 3.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [4.0, 5.0, 6.0], token, + [4.0, 5.0, 6.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [1.1, 1.2, 600.0], token, + [1.1, 1.2, 600.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [1.0, 100.0, 6.0], token, + [1.0, 100.0, 6.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [4.0, 9.0, 8.0], token, + [4.0, 9.0, 8.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), @@ -259,22 +259,22 @@ def test_meta_schedule_database_reload(): records = [ TuningRecord( trace, - [7.0, 8.0, 9.0], token, + [7.0, 8.0, 9.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [1.0, 2.0, 3.0], token, + [1.0, 2.0, 3.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), TuningRecord( trace, - [4.0, 5.0, 6.0], token, + [4.0, 5.0, 6.0], tvm.target.Target("llvm"), ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object ), diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index a423bdb48afdf..3b33039bd2874 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -267,7 +267,7 @@ def test_meta_schedule_integration_apply_history_best(): target = Target("llvm") workload = database.commit_workload(MockModule) database.commit_tuning_record( - TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []) + TuningRecord(Schedule(MockModule).trace, workload, [1.0], target, []) ) mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule]) assert tvm.ir.structural_equal(mod, workload.mod) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index e5076af520f30..e0883dbd227ed 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ 
b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -307,8 +307,8 @@ def test_meta_schedule_relay_lowering(): database.commit_tuning_record( TuningRecord( Trace([], {}), - [0.0], database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc), + [0.0], target=target, args_info=[], ) From f5f9600614c4aa933c863001459d92b13d9b72fc Mon Sep 17 00:00:00 2001 From: "Ehsan M. Kermani" <6980212+ehsanmok@users.noreply.github.com> Date: Tue, 7 Jun 2022 22:08:28 -0700 Subject: [PATCH 064/181] [docs] Various content corrections (#11517) * [docs] Various content corrections * Fix underline title --- gallery/how_to/deploy_models/deploy_sparse.py | 8 ++++---- .../how_to/extend_tvm/bring_your_own_datatypes.py | 2 +- gallery/how_to/extend_tvm/low_level_custom_pass.py | 2 +- gallery/how_to/extend_tvm/use_pass_infra.py | 8 ++++---- gallery/how_to/extend_tvm/use_pass_instrument.py | 4 ++-- gallery/how_to/optimize_operators/opt_conv_cuda.py | 2 +- .../optimize_operators/opt_conv_tensorcore.py | 2 +- gallery/how_to/optimize_operators/opt_gemm.py | 4 ++-- .../how_to/tune_with_autotvm/tune_conv2d_cuda.py | 2 +- gallery/how_to/work_with_relay/build_gcn.py | 2 +- gallery/how_to/work_with_relay/using_relay_viz.py | 6 +++--- gallery/how_to/work_with_schedules/extern_op.py | 4 ++-- gallery/how_to/work_with_schedules/intrin_math.py | 2 +- gallery/how_to/work_with_schedules/scan.py | 2 +- gallery/tutorial/auto_scheduler_matmul_x86.py | 4 ++-- gallery/tutorial/autotvm_matmul_x86.py | 14 +++++++------- gallery/tutorial/intro_topi.py | 3 +-- gallery/tutorial/tensor_expr_get_started.py | 8 ++++---- gallery/tutorial/tensor_ir_blitz_course.py | 6 +++--- 19 files changed, 42 insertions(+), 43 deletions(-) diff --git a/gallery/how_to/deploy_models/deploy_sparse.py b/gallery/how_to/deploy_models/deploy_sparse.py index 768a697f45cfc..56a5f1aafd1ce 100644 --- a/gallery/how_to/deploy_models/deploy_sparse.py +++ b/gallery/how_to/deploy_models/deploy_sparse.py @@ -36,11 +36,11 @@ Pruning is a technique primarily used to reduce the parameter size of a model by replacing weight values with 0s. Although many methods exist for choosing which -weights should be set to 0, the most straight forward is by picking the +weights should be set to 0, the most straight forward is by picking the weights with the smallest value. Typically, weights are pruned to a desired sparsity percentage. For example, a 95% sparse model would have only 5% of its weights non-zero. Pruning to very high sparsities often requires -finetuning or full retraining as it tends to be a lossy approximation. +fine-tuning or full retraining as it tends to be a lossy approximation. Although parameter size benefits are quite easy to obtain from a pruned model through simple compression, leveraging sparsity to yield runtime speedups is more complicated. @@ -50,8 +50,8 @@ value and location. The benefit of bunching up pruned weights is that it allows an algorithm such as matrix multiplication to skip entire blocks. It turns out that some degree of *block sparsity* is very important to realizing significant -speedups on most hardware available today. -This is because when loading memory in most CPUs or GPUs, +speedups on most hardware available today. +This is because when loading memory in most CPUs or GPUs, it doesn't save any work to skip reading a single value at a time, instead an entire chunk or tile is read in and executed using something like vectorized instructions. 
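A quick aside to make the block-sparsity point above concrete. The following is an illustrative sketch only, not part of this patch: it assumes NumPy and SciPy are available, and the shapes, block size, and sparsity level are made up for demonstration.

    import numpy as np
    from scipy import sparse

    dense = np.random.rand(64, 64).astype("float32")
    # Prune whole (16, 1) blocks so the sparsity is block-structured,
    # leaving roughly 5% of the blocks non-zero.
    keep_blocks = np.random.rand(4, 64) > 0.95
    dense[~np.repeat(keep_blocks, 16, axis=0)] = 0.0

    # BSR storage keeps only the non-zero blocks, so the multiply walks whole
    # blocks (friendly to vectorized loads) instead of testing single values.
    bsr = sparse.bsr_matrix(dense, blocksize=(16, 1))
    x = np.random.rand(64, 1).astype("float32")
    np.testing.assert_allclose(bsr.dot(x), dense @ x, rtol=1e-5)

The same idea is what lets a block-sparse matmul skip the pruned blocks entirely at runtime rather than reading them and multiplying by zero.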
diff --git a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py index 018245609923a..1a48781e24336 100644 --- a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py +++ b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py @@ -313,7 +313,7 @@ def convert_ndarray(dst_dtype, array): print(str(e).split("\n")[-1]) ###################################################################### -# When we attempt to run the model, we get a familiar error telling us that more functions need to be registerd for myfloat. +# When we attempt to run the model, we get a familiar error telling us that more functions need to be registered for myfloat. # # Because this is a neural network, many more operations are required. # Here, we register all the needed functions: diff --git a/gallery/how_to/extend_tvm/low_level_custom_pass.py b/gallery/how_to/extend_tvm/low_level_custom_pass.py index 8f631075429fd..ee96d8220cac3 100644 --- a/gallery/how_to/extend_tvm/low_level_custom_pass.py +++ b/gallery/how_to/extend_tvm/low_level_custom_pass.py @@ -129,7 +129,7 @@ def vectorize(f, mod, ctx): tvm.tir.stmt_functor.post_order_visit(f.body, find_width8) if not loops: - return sf + return f # The last list arugment indicates what kinds of nodes will be transformed. # Thus, in this case only `For` nodes will call `vectorize8` diff --git a/gallery/how_to/extend_tvm/use_pass_infra.py b/gallery/how_to/extend_tvm/use_pass_infra.py index 67cdfdedce0e8..e38383e69011a 100644 --- a/gallery/how_to/extend_tvm/use_pass_infra.py +++ b/gallery/how_to/extend_tvm/use_pass_infra.py @@ -35,7 +35,7 @@ pass infra. For more details about each type of these passes, please refer to the :ref:`pass-infra` -This tutorial mainly demostrates how developers can use the pass infra to perform +This tutorial mainly demonstrates how developers can use the pass infra to perform a certain optimization and create an optimization pipeline for a Relay program. The same approach can be used for tir as well. """ @@ -104,7 +104,7 @@ def example(): print(mod) ############################################################################### -# Some optimizations, such as fusion, are parameteric as well. For example, +# Some optimizations, such as fusion, are parametric as well. For example, # opt level 0 will not allow operators to be fused together. Users can pass the # `fuse_opt_level` to enable this. mod = relay.transform.FuseOps(fuse_opt_level=0)(mod) @@ -127,7 +127,7 @@ def example(): # these issues explicitly by specifying the required passes of each pass and # packing them as a whole to execute. For example, the same passes can now be # applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is -# similiar to `torch.nn.sequential `_ +# similar to `torch.nn.sequential `_ # and `mxnet.gluon.block `_. # For example, `torch.nn.sequential` is used to contain a sequence of PyTorch # `Modules` that will be added to build a network. It focuses on the network @@ -267,7 +267,7 @@ def run_before_pass(self, mod, info): # ------- # This tutorial has covered how we can write and invoke passes in TVM more # conveniently using the pass infra. Different ways of invoking a pass are also -# disucssed. Using :py:class:`tvm.transform.Sequential` can largely help +# discussed. Using :py:class:`tvm.transform.Sequential` can largely help # users to ease the work of handling multiple optimization passes and their # dependencies. 
In addition, an example is provided to illustrate # how we can debug a pass using the ``PrintIR`` and tracing. diff --git a/gallery/how_to/extend_tvm/use_pass_instrument.py b/gallery/how_to/extend_tvm/use_pass_instrument.py index 3369304a651d3..036aa63e374f0 100644 --- a/gallery/how_to/extend_tvm/use_pass_instrument.py +++ b/gallery/how_to/extend_tvm/use_pass_instrument.py @@ -30,7 +30,7 @@ for collecting timing information (:py:class:`tvm.ir.instrument.PassTimingInstrument`), but an extension mechanism is available via the :py:func:`tvm.instrument.pass_instrument` decorator. -This tutorial demostrates how developers can use ``PassContext`` to instrument +This tutorial demonstrates how developers can use ``PassContext`` to instrument passes. Please also refer to the :ref:`pass-infra`. """ import tvm @@ -314,7 +314,7 @@ def exit_pass_ctx(self): print("Catching", str(ex).split("\n")[-1]) ############################################################################### -# Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` +# Exceptions occurred in ``should_run``, ``run_before_pass``, ``run_after_pass`` # are not handled explicitly -- we rely on the context manager (the ``with`` syntax) # to exit ``PassContext`` safely. # diff --git a/gallery/how_to/optimize_operators/opt_conv_cuda.py b/gallery/how_to/optimize_operators/opt_conv_cuda.py index 0ac2c625bf781..3d2caa0d31214 100644 --- a/gallery/how_to/optimize_operators/opt_conv_cuda.py +++ b/gallery/how_to/optimize_operators/opt_conv_cuda.py @@ -97,7 +97,7 @@ # :width: 271px # # In this example, we load both Apad and W into buffer AA and WW, which are -# stored in the shared memory. These bufferes will be later shared by all +# stored in the shared memory. These buffers will be later shared by all # threads within the same thread block to compute the convolution. Each thread # then loads its own part from shared buffer into their local registers, AL and # WL. BL is a local cache of output B, which is also stored in the thread local diff --git a/gallery/how_to/optimize_operators/opt_conv_tensorcore.py b/gallery/how_to/optimize_operators/opt_conv_tensorcore.py index 702e4a777df57..ccfc7b9743aaa 100644 --- a/gallery/how_to/optimize_operators/opt_conv_tensorcore.py +++ b/gallery/how_to/optimize_operators/opt_conv_tensorcore.py @@ -306,7 +306,7 @@ def intrin_func(ins, outs): # *Warp-level Operation* # # Note that all TensorCore instructions are warp-level instructions, which means all 32 threads -# in a warp should do this instruction simultaneously. Making theadIdx.x extent=32 is one of the +# in a warp should do this instruction simultaneously. Making threadIdx.x extent=32 is one of the # easiest way to solve this. Then We can bind threadIdx.x to any loops except those contain # TensorCore intrinsics directly or indirectly. Also note that it is not the unique solution. # The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time. 
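To make the threadIdx.x remark above concrete, here is a minimal sketch, not part of this patch, of the usual TE binding pattern: the tensor shapes and names (n, A, B) are invented for illustration, and it only lowers the schedule rather than running it, so no GPU is needed.

    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # Split so the inner loop has extent 32 (one warp), then bind it to
    # threadIdx.x so all 32 lanes of the warp execute it together.
    bx, tx = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    print(tvm.lower(s, [A, B], simple_mode=True))

The tutorial's actual TensorCore schedule follows the same pattern, binding a loop of extent 32 to threadIdx.x outside the loops that carry the TensorCore intrinsics.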
diff --git a/gallery/how_to/optimize_operators/opt_gemm.py b/gallery/how_to/optimize_operators/opt_gemm.py index 5d698c612ee8f..920d7a87fabf9 100644 --- a/gallery/how_to/optimize_operators/opt_gemm.py +++ b/gallery/how_to/optimize_operators/opt_gemm.py @@ -312,7 +312,7 @@ s[CC].reorder(ko, mc, ki, nc) s[CC].vectorize(nc) -# TODO: Add separate optimization step to discuss loop unrolloing +# TODO: Add separate optimization step to discuss loop unrolling # unrolling is a loop optimization strategy which can reduce branch # prediction failures and increases the chance of concurrent execution # unroll kfactor loops @@ -390,4 +390,4 @@ # our generated code can achieve 60% of the `numpy` performance with MKL. # Note that the outputs on the web page reflect the running times on a non-exclusive # Docker container, thereby they are *unreliable*. It is highly encouraged to run the -# tutorial by yourself to observe the performance gain acheived by TVM. +# tutorial by yourself to observe the performance gain achieved by TVM. diff --git a/gallery/how_to/tune_with_autotvm/tune_conv2d_cuda.py b/gallery/how_to/tune_with_autotvm/tune_conv2d_cuda.py index ef921563e466f..e3072773bf593 100644 --- a/gallery/how_to/tune_with_autotvm/tune_conv2d_cuda.py +++ b/gallery/how_to/tune_with_autotvm/tune_conv2d_cuda.py @@ -74,7 +74,7 @@ # # If you are familiar with writing cuda schedule, you can find the following # template is very general. Actually this template can be easily modified -# to tune other operators such as depthwise convolution and gemm. +# to tune other operators such as depthwise convolution and GEMM. # In order to fully understand this template, you should be familiar with # the schedule primitives and auto tuning API. You can refer to the above # tutorials and :ref:`autotvm tutorial ` diff --git a/gallery/how_to/work_with_relay/build_gcn.py b/gallery/how_to/work_with_relay/build_gcn.py index d76baec1eec14..fcffbd77ff86b 100644 --- a/gallery/how_to/work_with_relay/build_gcn.py +++ b/gallery/how_to/work_with_relay/build_gcn.py @@ -314,7 +314,7 @@ def prepare_params(g, data): # Compile and run with TVM # ------------------------ # -# Export the weigths from PyTorch model to Python Dict +# Export the weights from PyTorch model to Python Dict model_params = {} for param_tensor in torch_model.state_dict(): model_params[param_tensor] = torch_model.state_dict()[param_tensor].numpy() diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py index 10e6dab12e245..b0132f40b9b51 100644 --- a/gallery/how_to/work_with_relay/using_relay_viz.py +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -22,7 +22,7 @@ Relay IR module can contain lots of operations. Although an individual operation is usually easy to understand, putting them together can cause -a complicated, hard-to-read graph. Things can get even worse with optimiztion-passes +a complicated, hard-to-read graph. Things can get even worse with optimization-passes coming into play. This utility visualizes an IR module as nodes and edges. It defines a set of interfaces including @@ -89,7 +89,7 @@ # ------------------------------------------- # Sometimes we want to emphasize interested information, or parse things differently for a specific usage. # It is possible to provide customized parsers as long as it obeys the interface. -# Here demostrate how to customize parsers for ``relay.var``. +# Here demonstrate how to customize parsers for ``relay.var``. 
# We need to implement abstract interface :py:class:`tvm.contrib.relay_viz.interface.VizParser`. class YourAwesomeParser(VizParser): def __init__(self): @@ -131,7 +131,7 @@ def node(self, viz_node): super().node(viz_node) # if it's AwesomeVar, duplicate it. if viz_node.type_name == "AwesomeVar": - duplicated_id = f"duplciated_{viz_node.identity}" + duplicated_id = f"duplicated_{viz_node.identity}" duplicated_type = "double AwesomeVar" super().node(VizNode(duplicated_id, duplicated_type, "")) # connect the duplicated var to the original one diff --git a/gallery/how_to/work_with_schedules/extern_op.py b/gallery/how_to/work_with_schedules/extern_op.py index fb9b2eaf8d13b..a0aa5d72450c0 100644 --- a/gallery/how_to/work_with_schedules/extern_op.py +++ b/gallery/how_to/work_with_schedules/extern_op.py @@ -25,7 +25,7 @@ some of the convolution kernels and define the rest of the stages. TVM supports these black box function calls natively. -Specfically, TVM support all the tensor functions that are DLPack compatible. +Specifically, TVM support all the tensor functions that are DLPack compatible. Which means we can call any function with POD types(pointer, int, float) or pointer to DLTensor as argument. """ @@ -52,7 +52,7 @@ # list of symbolic placeholder for the outputs and returns the executing statement. # # In this case we simply call a registered TVM function, which invokes a CBLAS call. -# TVM does not control internal of the extern array function and treats it as blackbox. +# TVM does not control internal of the extern array function and treats it as black-box. # We can further mix schedulable TVM calls that add a bias term to the result. # n = 1024 diff --git a/gallery/how_to/work_with_schedules/intrin_math.py b/gallery/how_to/work_with_schedules/intrin_math.py index 92383b90a53f9..535563bfb5306 100644 --- a/gallery/how_to/work_with_schedules/intrin_math.py +++ b/gallery/how_to/work_with_schedules/intrin_math.py @@ -26,7 +26,7 @@ These functions are target system dependent and may have different names of different target platforms. In this tutorial, we will learn how we can invoke these target specific functions, and how we can unify -the interface via tvm's intrinsic API. +the interface via TVM's intrinsic API. """ from __future__ import absolute_import, print_function import numpy as np diff --git a/gallery/how_to/work_with_schedules/scan.py b/gallery/how_to/work_with_schedules/scan.py index ba8b5a9f8e06a..3f3d7e91ee1c1 100644 --- a/gallery/how_to/work_with_schedules/scan.py +++ b/gallery/how_to/work_with_schedules/scan.py @@ -60,7 +60,7 @@ # Schedule the Scan Cell # ---------------------- # We can schedule the body of the scan by scheduling the update and -# init part seperately. Note that it is invalid to schedule the +# init part separately. Note that it is invalid to schedule the # first iteration dimension of the update part. # To split on the time iteration, user can schedule on scan_op.scan_axis instead. # diff --git a/gallery/tutorial/auto_scheduler_matmul_x86.py b/gallery/tutorial/auto_scheduler_matmul_x86.py index 9f3a6070ccb23..b9f89f6723c9b 100644 --- a/gallery/tutorial/auto_scheduler_matmul_x86.py +++ b/gallery/tutorial/auto_scheduler_matmul_x86.py @@ -78,13 +78,13 @@ def matmul_add(N, L, M, dtype): # ---------------------- # With the function defined, we can now create the task for the auto_scheduler # to search against. 
We specify the particular parameters for this matrix -# multiplication, in this case a multiplication of to square matricies of size +# multiplication, in this case a multiplication of two square matrices of size # 1024x1024. We then create a search task with N=L=M=1024 and dtype="float32" # # .. admonition:: Improve performance with custom targets # # In order for TVM to take full advantage of specific hardware platforms, -# you will want to manuall specify your CPU capabilities. For example: +# you will want to manually specify your CPU capabilities. For example: # # - replace ``llvm`` below with ``llvm -mcpu=core-avx2`` to enable AVX2 # - replace ``llvm`` below with ``llvm -mcpu=skylake-avx512`` to enable AVX-512 diff --git a/gallery/tutorial/autotvm_matmul_x86.py b/gallery/tutorial/autotvm_matmul_x86.py index 54581172115d2..b84a6193cde6e 100644 --- a/gallery/tutorial/autotvm_matmul_x86.py +++ b/gallery/tutorial/autotvm_matmul_x86.py @@ -28,7 +28,7 @@ find the optimal schedule. This process is called Auto-Tuning, which helps automate the process of optimizing tensor computation. -This tutorial builds on the previous `tutorial on how to write a matrix +This tutorial builds on the previous :doc:`tutorial on how to write a matrix multiplication using TE `. There are two steps in auto-tuning. @@ -201,7 +201,7 @@ def matmul_v1(N, L, M, dtype): # knob. This is the lowest level API to define the space, and gives an explicit # enumeration of the parameter space to search. However, we also provide # another set of APIs that can make the definition of the search space easier -# and smarter. Where possible, we receomment you use this higher-level API +# and smarter. Where possible, we recommend you use this higher-level API # # In the following example, we use :any:`ConfigSpace.define_split` to define a # split knob. It will enumerate all the possible ways to split an axis and @@ -267,7 +267,7 @@ def matmul(N, L, M, dtype): # Step 2: Use AutoTVM to Optimize the Matrix Multiplication # --------------------------------------------------------- # In Step 1, we wrote a matrix multiplication template that allowed us to -# paramaterize the block size used in the `split` schedule. We can now conduct +# parameterize the block size used in the `split` schedule. We can now conduct # a search over this parameter space. The next step is to pick a tuner to guide # the exploration of this space. # @@ -295,7 +295,7 @@ def matmul(N, L, M, dtype): # # You can choose the tuner according to the size of your space, your time # budget and other factors. For example, if your space is very small (less -# than 1000), a gridsearch tuner or a random tuner is good enough. If your +# than 1000), a grid-search tuner or a random tuner is good enough. If your # space is at the level of 10^9 (this is the space size of a conv2d operator on # CUDA GPU), XGBoostTuner can explore more efficiently and find better configs. @@ -342,7 +342,7 @@ def matmul(N, L, M, dtype): ################################################################################ # With tuning completed, we can choose the configuration from the log file that # has the best measured performance and compile the schedule with the -# corresponding parameters. We also do a quick verfication that the schedule is +# corresponding parameters. We also do a quick verification that the schedule is # producing correct answers. We can call the function :code:`matmul` directly # under the :any:`autotvm.apply_history_best` context. 
When we call this # function, it will query the dispatch context with its argument and get the @@ -371,7 +371,7 @@ def matmul(N, L, M, dtype): # TVM to search a parameter space and choose optimized schedule configurations. # To gain a deeper understanding of how this works, we recommend expanding on # this example by adding new search parameters to the schedule based on -# schedule operations demonstated in the `Getting Started With Tensor +# schedule operations demonstrated in the :ref: `Getting Started With Tensor # Expressions _` tutorial. In the upcoming sections, we -# will demonstate the AutoScheduler, a method for TVM to optimize common +# will demonstrate the AutoScheduler, a method for TVM to optimize common # operators without the need for the user to provide a user-defined template. diff --git a/gallery/tutorial/intro_topi.py b/gallery/tutorial/intro_topi.py index dad8c53bf4ae3..17fa3ff370e54 100644 --- a/gallery/tutorial/intro_topi.py +++ b/gallery/tutorial/intro_topi.py @@ -23,9 +23,8 @@ This is an introductory tutorial to TVM Operator Inventory (TOPI). TOPI provides numpy-style generic operations and schedules with higher abstractions than TVM. -In this tutorial, we will see how TOPI can save us from writing boilerplates code in TVM. +In this tutorial, we will see how TOPI can save us from writing boilerplate code in TVM. """ -from __future__ import absolute_import, print_function import tvm import tvm.testing diff --git a/gallery/tutorial/tensor_expr_get_started.py b/gallery/tutorial/tensor_expr_get_started.py index 7d8c0d781a3f3..25ea4e8a55ee5 100644 --- a/gallery/tutorial/tensor_expr_get_started.py +++ b/gallery/tutorial/tensor_expr_get_started.py @@ -187,8 +187,8 @@ def evaluate_addition(func, target, optimization, log): evaluate_addition(fadd, tgt, "naive", log=log) ################################################################################ -# Updating the Schedule to Use Paralleism -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Updating the Schedule to Use Parallelism +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # Now that we've illustrated the fundamentals of TE, let's go deeper into what # schedules do, and how they can be used to optimize tensor expressions for @@ -754,7 +754,7 @@ def evaluate_operation(s, vars, target, name, optimization, log): # regular but discontinuous. We expect that after some transformation we can # get a continuous access pattern. By reordering a ``[16][16]`` array to a # ``[16/4][16][4]`` array the access pattern of B will be sequential when -# grabing the corresponding value from the packed array. +# grabbing the corresponding value from the packed array. # # To accomplish this, we are going to have to start with a new default # schedule, taking into account the new packing of B. It's worth taking a @@ -889,7 +889,7 @@ def evaluate_operation(s, vars, target, name, optimization, log): # have from this introduction to TE, we can now begin to explore how TVM can # automate the schedule optimization process. # -# This tutorial provided a walkthrough of TVM Tensor Expresstion (TE) workflow +# This tutorial provided a walk-through of TVM Tensor Expression (TE) workflow # using a vector add and a matrix multiplication examples. 
The general workflow # is # diff --git a/gallery/tutorial/tensor_ir_blitz_course.py b/gallery/tutorial/tensor_ir_blitz_course.py index e9a0801f34a81..11edc7ae9f3b9 100644 --- a/gallery/tutorial/tensor_ir_blitz_course.py +++ b/gallery/tutorial/tensor_ir_blitz_course.py @@ -25,7 +25,7 @@ - An implementation for transforming and optimizing programs on various hardware backends. -- An abstraction for automatic tensorized program optimization. +- An abstraction for automatic _tensorized_ program optimization. """ @@ -145,7 +145,7 @@ def main(a: T.handle, b: T.handle): # sequence of schedule primitives will help to improve the performance. And at last, we can lower # and build it into a runnable module. # -# Here we just demostrate a very simple tranformation. First we create schedule on the input `ir_module`. +# Here we just demonstrate a very simple transformation. First we create schedule on the input `ir_module`. sch = tvm.tir.Schedule(ir_module) print(type(sch)) @@ -155,7 +155,7 @@ def main(a: T.handle, b: T.handle): # Get block by its name block_b = sch.get_block("B") -# Get loops surronding the block +# Get loops surrounding the block (i,) = sch.get_loops(block_b) # Tile the loop nesting. i_0, i_1, i_2 = sch.split(i, factors=[2, 2, 2]) From a95a820cfaa0fa5d83e2f6a7c304c61e0de782c1 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Wed, 8 Jun 2022 13:41:02 +0800 Subject: [PATCH 065/181] [DNNL] Fix end of line in test_dnnl UT file (#11560) --- tests/python/contrib/test_dnnl.py | 2072 ++++++++++++++--------------- 1 file changed, 1036 insertions(+), 1036 deletions(-) diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 19ac183d66dfe..babfad4a0c8c7 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -1,1036 +1,1036 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import pytest -import itertools -import numpy as np -import sys -import subprocess - -import tvm -from tvm import relay -from tvm.relay import transform -from tvm.relay.build_module import bind_params_by_name -from tvm.relay.testing.temp_op_attr import TempOpAttr -from tvm.relay.op.contrib import dnnl -import tvm.testing - - -has_dnnl_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available" -) - -run_module = tvm.testing.parameter( - pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), - pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), - ids=["compile", "run"], -) - -_bf16_supported = None - - -def bf16_supported(): - global _bf16_supported - if _bf16_supported is None: - _bf16_supported = False - if sys.platform.startswith("darwin"): - cpu_info = subprocess.check_output("sysctl -a", shell=True).strip().decode() - for line in cpu_info.split("\n"): - if line.startswith("hw.optional.avx512f"): - _bf16_supported = bool(line.split(":", 1)[1]) - elif sys.platform.startswith("linux"): - _bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() - return _bf16_supported - - -def partition_for_dnnl(mod, params=None, alter_layout=True): - """Partition the graph greedily offloading supported operators to DNNL. - - Parameters - ---------- - mod : Module - The module to run passes on. - params : Optional[Dict[str, NDArray]] - Constant input parameters. - Returns - ------- - mod : Module - Annotated and partitioned module. - """ - if params: - mod["main"] = bind_params_by_name(mod["main"], params) - - with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv): - with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv): - seq = tvm.transform.Sequential( - [ - transform.CanonicalizeOps(), - transform.InferType(), - transform.SimplifyInference(), - transform.FoldConstant(), - transform.FoldScaleAxis(), - # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu` - transform.SimplifyExpr(), - transform.FoldConstant(), - # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW` - transform.Legalize(), - transform.FoldConstant(), - ] - ) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - if alter_layout: - with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv): - with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv): - with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv): - with TempOpAttr( - "nn.conv2d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose - ): - with TempOpAttr( - "nn.conv3d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose - ): - alter_layout_seq = tvm.transform.Sequential( - [ - transform.AlterOpLayout(), - transform.FoldConstant(), - ] - ) - with tvm.transform.PassContext(opt_level=3): - mod = alter_layout_seq(mod) - - byoc_seq = tvm.transform.Sequential( - [ - transform.MergeComposite(dnnl.pattern_table()), - transform.AnnotateTarget("dnnl"), - transform.MergeCompilerRegions(), - transform.PartitionGraph(), - ] - ) - with tvm.transform.PassContext(opt_level=3): - mod = byoc_seq(mod) - mod = dnnl.prune_dnnl_subgraphs(mod) - return mod - - -def vmobj_to_list(o): - if isinstance(o, tvm.nd.NDArray): - o_np = o.numpy() - if o_np.dtype == np.uint16: - o_np = np.left_shift(o_np.astype("uint32"), 16).view("= 1 - - dev = tvm.cpu() - result_dict = dict() - for mode in ["graph", "vm"]: - configs = [ - (False, False, False), - (True, False, False), - 
(True, True, False), - ] - if test_bf16 and bf16_supported(): - configs += [(True, False, True), (True, True, True)] - for use_dnnl, alter_layout, use_bf16 in configs: - result_key = ( - mode - + ("_dnnl" if use_dnnl else "") - + ("_layout" if alter_layout else "") - + ("_bf16" if use_bf16 else "_fp32") - ) - processed_mod = mod - if use_bf16: - processed_mod = relay.transform.ToMixedPrecision("bfloat16")(processed_mod) - if tvm.ir.structural_equal(processed_mod, mod): - print("can not convert to bfloat16, skipping...") - continue - if use_dnnl: - processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) - check_dnnl_used(processed_mod) - - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=processed_mod, device=dev, target=target - ).evaluate() - if run_module: - if isinstance(input, dict): - result_dict[result_key] = func(**input, **params) - else: - result_dict[result_key] = func(input, **params) - - if run_module: - assert_result_dict_holds(result_dict) - - -def run_and_verify_func( - config, run_module, subgraph_num=None, target="llvm", dtype="float32", test_bf16=True -): - """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs. - Parameters - ---------- - config : Tuple[relay.Function, Dict[str, NDArray], List[str]] - A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and - 3) A list of which vars should be considered params. - run_module: bool - If True, the built module will be run after being compiled. - """ - f, input_shapes, is_param = config - params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype) for x in is_param} - input_dict = { - k: np.random.uniform(-1, 1, v).astype(dtype) - for k, v in input_shapes.items() - if k not in is_param - } - run_and_verify( - f, - input_dict, - params, - subgraph_num=subgraph_num, - target=target, - run_module=run_module, - test_bf16=test_bf16, - ) - - -def get_conv1d( - x_shape=((1, 3, 224)), - k_shape=(16, 3, 3), - groups=1, - padding=(1, 1), - strides=(1), - dilation=(1), - channels=None, - activation=None, - dtype="float32", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) - out = relay.nn.conv1d( - x, - kernel, - kernel_size=k_shape[2:3], - groups=groups, - padding=padding, - strides=strides, - dilation=dilation, - channels=k_shape[0], - ) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"): - conv, dic, param_lst = get_conv1d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) - out = relay.nn.bias_add(conv, bias) - dic["bias"] = (k_shape[0],) - param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"): - conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, dtype=dtype) - beta = 
relay.const(np.zeros(k_shape[0]).astype(dtype)) - gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) - moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) - moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) - conv1d_bias_bn, _, _ = relay.nn.batch_norm( - conv1d_bias, - gamma=gamma, - beta=beta, - moving_mean=moving_mean, - moving_var=moving_var, - axis=1, - center=True, - scale=True, - epsilon=1e-5, - ) - return relay.nn.relu(conv1d_bias_bn), dic, param_lst - - -def get_conv2d( - x_shape=(1, 32, 8, 8), - k_shape=(16, 32, 3, 3), - groups=1, - padding=(0, 0), - strides=(1, 1), - dilation=(1, 1), - activation=None, - dtype="float32", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) - out = relay.nn.conv2d( - x, - kernel, - kernel_size=k_shape[2:4], - groups=groups, - padding=padding, - strides=strides, - dilation=dilation, - channels=k_shape[0], - ) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv2d_transpose( - x_shape=(1, 32, 8, 8), - k_shape=(32, 16, 3, 3), - groups=1, - padding=(0, 0), - strides=(1, 1), - activation=None, - dtype="float32", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) - out = relay.nn.conv2d_transpose( - x, - kernel, - channels=k_shape[1] * groups, - kernel_size=k_shape[2:4], - groups=groups, - padding=padding, - strides=strides, - ) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv2d_weights_const( - x_shape=(1, 32, 8, 8), - k_shape=(16, 32, 3, 3), - groups=1, - padding=(0, 0), - strides=(1, 1), - dilation=(1, 1), - dtype="float32", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) - out = relay.nn.conv2d( - x, - kernel, - channels=k_shape[0], - kernel_size=k_shape[2:4], - groups=groups, - padding=padding, - strides=strides, - dilation=dilation, - ) - dic = {"x": x_shape} - param_lst = [] - return out, dic, param_lst - - -def get_conv2d_bias( - x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32" -): - conv, dic, param_lst = get_conv2d_weights_const(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) - out = relay.nn.bias_add(conv, bias) - dic["bias"] = (k_shape[0],) - param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv2d_transpose_bias( - x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), activation=None, dtype="float32" -): - conv, dic, param_lst = get_conv2d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) - out = relay.nn.bias_add(conv, bias) - dic["bias"] = (k_shape[1],) - 
param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): - conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) - beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) - gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) - moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) - moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) - conv2d_bias_bn, _, _ = relay.nn.batch_norm( - conv2d_bias, - gamma=gamma, - beta=beta, - moving_mean=moving_mean, - moving_var=moving_var, - axis=1, - center=True, - scale=True, - epsilon=1e-5, - ) - return relay.nn.relu(conv2d_bias_bn), dic, param_lst - - -def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): - conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) - sum_data = relay.const(np.random.randint(x_shape).astype(dtype)) - conv2d_bias_sum = relay.add(sum_data, conv2d_bias) - return relay.nn.relu(conv2d_bias_sum), dic, param_lst - - -def get_conv3d( - x_shape=(1, 32, 8, 8, 8), - k_shape=(16, 32, 3, 3, 3), - groups=1, - padding=(0, 0, 0), - strides=(1, 1, 1), - dilation=(1, 1, 1), - activation=None, - dtype="float32", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) - out = relay.nn.conv3d( - x, - kernel, - channels=k_shape[0], - kernel_size=k_shape[2:], - groups=groups, - padding=padding, - strides=strides, - dilation=dilation, - ) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv3d_transpose( - x_shape=(1, 32, 8, 8, 8), - k_shape=(32, 16, 3, 3, 3), - groups=1, - padding=(0, 0, 0), - strides=(1, 1, 1), - output_padding=(0, 0, 0), - activation=None, - dtype="float32", - data_layout="NCDHW", - kernel_layout="OIDHW", -): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) - out = relay.nn.conv3d_transpose( - x, - kernel, - channels=k_shape[1], - kernel_size=k_shape[2:5], - groups=groups, - padding=padding, - strides=strides, - output_padding=output_padding, - data_layout=data_layout, - kernel_layout=kernel_layout, - ) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv3d_bias( - x_shape=(1, 32, 8, 8, 8), k_shape=(16, 32, 3, 3, 3), activation=None, dtype="float32" -): - conv, dic, param_lst = get_conv3d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) - out = relay.nn.bias_add(conv, bias) - dic["bias"] = (k_shape[0],) - param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), 
dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_conv3d_transpose_bias( - x_shape=(1, 32, 8, 8, 8), k_shape=(32, 16, 3, 3, 3), activation=None, dtype="float32" -): - conv, dic, param_lst = get_conv3d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) - out = relay.nn.bias_add(conv, bias) - dic["bias"] = (k_shape[1],) - param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst - - -def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): - x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) - out = relay.nn.dense(x, kernel, units=k_shape[0]) - dic = {"x": x_shape, "kernel": k_shape} - param_lst = ["kernel"] - return out, dic, param_lst - - -def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): - dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype) - bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) - out = relay.nn.bias_add(dense, bias) - dic["bias"] = (k_shape[0],) - param_lst += ["bias"] - return out, dic, param_lst - - -def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"): - xshape = (1, 32, 14, 14) - x_data = np.random.uniform(-1, 1, xshape).astype(dtype) - - x = relay.var("x", shape=(xshape), dtype=dtype) - y = relay.add(x, x) - z = relay.cast(relay.cast(y, "int32"), "float32") - out = relay.nn.relu(z) - f = relay.Function([x], out) - mod = tvm.IRModule() - mod["main"] = f - mod = partition_for_dnnl(mod) - for mode in ["graph", "vm"]: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate() - if run_module: - results = func(x_data) - - -def test_multiple_outputs(run_module, dtype="float32"): - def get_graph(): - x = relay.var("x", shape=(1, 3), dtype=dtype) - y = relay.var("y", shape=(1, 3), dtype=dtype) - z = relay.add(x, y) - w = relay.add(z, y) - out = relay.Tuple((z, w)) - f = tvm.IRModule.from_expr(out) - return f, {"x": (1, 3), "y": (1, 3)}, [] - - run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype) - - -def test_elementwise(run_module, dtype="float32"): - def get_graph(op, x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype=dtype) - out = op(x) - f = tvm.IRModule.from_expr(out) - return f, {"x": x_shape}, [] - - for op in [ - relay.abs, - relay.exp, - relay.log, - relay.sqrt, - relay.nn.relu, - relay.tanh, - relay.sigmoid, - ]: - run_and_verify_func(get_graph(op), run_module=run_module) - - -def test_clip(run_module, dtype="float32"): - def get_graph(x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype=dtype) - out = relay.clip(x, a_min=-0.2, a_max=0.4) - f = tvm.IRModule.from_expr(out) - return f, {"x": x_shape}, [] - - run_and_verify_func(get_graph(), run_module=run_module) - - -def test_leaky_relu(run_module, dtype="float32"): - def get_graph(x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype=dtype) - out = relay.nn.leaky_relu(x, alpha=0.1) - f = tvm.IRModule.from_expr(out) - return f, {"x": x_shape}, [] - - run_and_verify_func(get_graph(), run_module=run_module) - - -def 
test_softmax(run_module, dtype="float32"): - def get_graph(x_shape, axis): - x = relay.var("x", shape=(x_shape), dtype=dtype) - out = relay.nn.softmax(x, axis=axis) - f = tvm.IRModule.from_expr(out) - return f, {"x": x_shape}, [] - - run_and_verify_func(get_graph((1, 1000), axis=1), run_module=run_module) - run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module) - run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module) - run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module) - - -def test_conv1d(run_module, dtype="float32"): - conv1d, dic, param_lst = get_conv1d(channels=16, dtype=dtype) - conv1d = tvm.IRModule.from_expr(conv1d) - config = conv1d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - x_shape = (1, 32, 224) - k_shape = (16, 32, 3) - conv1d_bias, dic, param_lst = get_conv1d(x_shape, k_shape, dtype=dtype) - conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) - config = conv1d_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv1d_pattern(run_module, dtype="float32"): - x_shape = (1, 3, 224) - k_shape = (16, 3, 3) - activation_lst = [None, "relu", "tanh", "sigmoid"] - for a in activation_lst: - conv1d, dic, param_lst = get_conv1d(x_shape, k_shape, activation=a, dtype=dtype) - conv1d = tvm.IRModule.from_expr(conv1d) - config = conv1d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, activation=a, dtype=dtype) - conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) - config = conv1d_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv2d(run_module, dtype="float32"): - x_shape = (1, 32, 8, 8) - for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 2, 3, 3), 16)]: - for padding in [(0, 0), (1, 1)]: - for strides in [(1, 1), (2, 2)]: - for dilation in [(1, 1), (2, 2)]: - conv2d, dic, param_lst = get_conv2d( - x_shape=x_shape, - k_shape=k_shape, - groups=groups, - padding=padding, - strides=strides, - dilation=dilation, - dtype=dtype, - ) - conv2d = tvm.IRModule.from_expr(conv2d) - config = conv2d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv2d_weights_const(run_module, dtype="float32"): - x_shape = (1, 32, 8, 8) - k_shape = (16, 32, 3, 3) - conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype) - conv2d = tvm.IRModule.from_expr(conv2d) - config = conv2d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - x_shape = (1, 3, 8, 8) - k_shape = (16, 3, 3, 3) - conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype) - conv2d = tvm.IRModule.from_expr(conv2d) - config = conv2d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv2d_pattern(run_module, dtype="float32"): - x_shape = (1, 32, 8, 8) - k_shape = (16, 32, 3, 3) - activation_lst = [None, "relu", "tanh", "sigmoid"] - for a in activation_lst: - conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype) - conv2d = tvm.IRModule.from_expr(conv2d) - config = conv2d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=a, dtype=dtype) - conv2d_bias = tvm.IRModule.from_expr(conv2d_bias) - config = conv2d_bias, dic, param_lst - 
run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype) - conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu) - config = conv2d_bias_bn_relu, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype) - conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu) - config = conv2d_bias_bn_relu, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv2d_transpose(run_module, dtype="float32"): - x_shape = (1, 32, 8, 8) - for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]: - for padding in [(0, 0), (1, 1)]: - for strides in [(1, 1), (2, 2)]: - conv2d_transpose, dic, param_lst = get_conv2d_transpose( - x_shape=x_shape, - k_shape=k_shape, - groups=groups, - padding=padding, - strides=strides, - dtype=dtype, - ) - conv2d_transpose = tvm.IRModule.from_expr(conv2d_transpose) - config = conv2d_transpose, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv2d_transpose_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid"] - for a in activation_lst: - conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype) - conv2d = tvm.IRModule.from_expr(conv2d) - config = conv2d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv2d_bias, dic, param_lst = get_conv2d_transpose_bias(activation=a, dtype=dtype) - conv2d_bias = tvm.IRModule.from_expr(conv2d_bias) - config = conv2d_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv3d(run_module, dtype="float32"): - conv3d, dic, param_lst = get_conv3d(dtype=dtype) - conv3d = tvm.IRModule.from_expr(conv3d) - config = conv3d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d, dic, param_lst = get_conv3d(padding=(0, 0, 0, 1, 1, 1), dtype=dtype) - conv3d = tvm.IRModule.from_expr(conv3d) - config = conv3d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d, dic, param_lst = get_conv3d( - x_shape=(1, 3, 8, 8, 8), k_shape=(16, 3, 3, 3, 3), dtype=dtype - ) - conv3d = tvm.IRModule.from_expr(conv3d) - config = conv3d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv3d_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid"] - for a in activation_lst: - conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype) - conv3d = tvm.IRModule.from_expr(conv3d) - config = conv3d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d_bias, dic, param_lst = get_conv3d_bias(activation=a, dtype=dtype) - conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) - config = conv3d_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv3d_transpose(run_module, dtype="float32"): - conv3d_transpose, dic, param_lst = get_conv3d_transpose(dtype=dtype) - conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) - config = conv3d_transpose, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d_transpose, dic, param_lst = get_conv3d_transpose(strides=(2, 2, 2), dtype=dtype) - conv3d_transpose = 
tvm.IRModule.from_expr(conv3d_transpose) - config = conv3d_transpose, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d_transpose, dic, param_lst = get_conv3d_transpose( - strides=(2, 2, 2), output_padding=(1, 1, 1), dtype=dtype - ) - conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) - config = conv3d_transpose, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_conv3d_transpose_pattern(run_module, dtype="float32"): - activation_lst = [None, "relu", "tanh", "sigmoid"] - for a in activation_lst: - conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype) - conv3d = tvm.IRModule.from_expr(conv3d) - config = conv3d, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - conv3d_bias, dic, param_lst = get_conv3d_transpose_bias(activation=a, dtype=dtype) - conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) - config = conv3d_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_dense(run_module, dtype="float32"): - x_shape = (1, 16) - k_shape = (32, 16) - - dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype) - dense = tvm.IRModule.from_expr(dense) - config = dense, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - dense, dic, param_lst = get_dense(x_shape, k_shape=(1, 16), dtype=dtype) - dense = tvm.IRModule.from_expr(dense) - config = dense, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_dense_pattern(run_module, dtype="float32"): - x_shape = (1, 16) - k_shape = (32, 16) - - dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype) - dense = tvm.IRModule.from_expr(dense) - config = dense, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, dtype=dtype) - dense_bias = tvm.IRModule.from_expr(dense_bias) - config = dense_bias, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) - - -def test_pool2d(run_module, dtype="float32"): - def get_graph( - op, - x_shape=(1, 3, 32, 32), - pool_size=(2, 2), - strides=(2, 2), - padding=(0, 0), - ceil_mode=False, - count_include_pad=None, - ): - x = relay.var("x", shape=(x_shape), dtype=dtype) - if count_include_pad is not None: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ) - else: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - ) - out = tvm.IRModule.from_expr(out) - return out, {"x": x_shape}, [] - - for pool_size in [(2, 2), (3, 3)]: - for strides in [(1, 1), (2, 2)]: - for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]: - for ceil_mode in [False]: - # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling" - if pool_size == (2, 2) and padding == (0, 0, 1, 1): - continue - for count_include_pad in [False, True]: - # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding" - if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)): - continue - run_and_verify_func( - get_graph( - relay.nn.avg_pool2d, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ), - run_module=run_module, - ) - run_and_verify_func( - get_graph( - 
relay.nn.max_pool2d, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - ), - run_module=run_module, - ) - - -def test_pool3d(run_module, dtype="float32"): - def get_graph( - op, - x_shape=(1, 3, 8, 32, 32), - pool_size=(2, 2, 2), - strides=(2, 2, 2), - padding=(0, 0, 0), - ceil_mode=False, - count_include_pad=None, - dtype="float32", - ): - x = relay.var("x", shape=(x_shape), dtype=dtype) - if count_include_pad is not None: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ) - else: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - ) - out = tvm.IRModule.from_expr(out) - return out, {"x": x_shape}, [] - - run_and_verify_func(get_graph(relay.nn.avg_pool3d), run_module=run_module) - run_and_verify_func(get_graph(relay.nn.max_pool3d), run_module=run_module) - run_and_verify_func( - get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)), run_module=run_module - ) - run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module) - - -def test_prune_dnnl_subgraph(run_module): - """In this test, OP "add" should be offloaded from dnnl codegen.""" - - def get_graph(): - x1 = relay.var("x1", shape=(1, 32, 56, 56)) - x2 = relay.var("x2", shape=(1, 32, 56, 56)) - bias = relay.var("bias", shape=(32,)) - weight = relay.var("weight", shape=(32, 32, 3, 3)) - y = relay.nn.conv2d( - x1, - weight, - channels=32, - kernel_size=(3, 3), - padding=(1, 1), - ) - y = relay.nn.bias_add(y, bias) - y = relay.nn.relu(y) - y = relay.nn.global_max_pool2d(y) - y = relay.add(y, x2) - dic = { - "x1": (1, 32, 56, 56), - "x2": (1, 32, 56, 56), - "weight": (32, 32, 3, 3), - "bias": (32,), - } - param_lst = ["weight", "bias"] - out = tvm.IRModule.from_expr(y) - return out, dic, param_lst - - run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False) - - -if __name__ == "__main__": - tvm.testing.main() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import itertools +import numpy as np +import sys +import subprocess + +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.relay.op.contrib import dnnl +import tvm.testing + + +has_dnnl_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available" +) + +run_module = tvm.testing.parameter( + pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), + pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm.marks()]), + ids=["compile", "run"], +) + +_bf16_supported = None + + +def bf16_supported(): + global _bf16_supported + if _bf16_supported is None: + _bf16_supported = False + if sys.platform.startswith("darwin"): + cpu_info = subprocess.check_output("sysctl -a", shell=True).strip().decode() + for line in cpu_info.split("\n"): + if line.startswith("hw.optional.avx512f"): + _bf16_supported = bool(line.split(":", 1)[1]) + elif sys.platform.startswith("linux"): + _bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() + return _bf16_supported + + +def partition_for_dnnl(mod, params=None, alter_layout=True): + """Partition the graph greedily offloading supported operators to DNNL. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + Returns + ------- + mod : Module + Annotated and partitioned module. + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv): + with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv): + seq = tvm.transform.Sequential( + [ + transform.CanonicalizeOps(), + transform.InferType(), + transform.SimplifyInference(), + transform.FoldConstant(), + transform.FoldScaleAxis(), + # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu` + transform.SimplifyExpr(), + transform.FoldConstant(), + # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW` + transform.Legalize(), + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + if alter_layout: + with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr( + "nn.conv2d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose + ): + with TempOpAttr( + "nn.conv3d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose + ): + alter_layout_seq = tvm.transform.Sequential( + [ + transform.AlterOpLayout(), + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = alter_layout_seq(mod) + + byoc_seq = tvm.transform.Sequential( + [ + transform.MergeComposite(dnnl.pattern_table()), + transform.AnnotateTarget("dnnl"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = byoc_seq(mod) + mod = dnnl.prune_dnnl_subgraphs(mod) + return mod + + +def vmobj_to_list(o): + if isinstance(o, tvm.nd.NDArray): + o_np = o.numpy() + if o_np.dtype == np.uint16: + o_np = np.left_shift(o_np.astype("uint32"), 16).view("= 1 + + dev = tvm.cpu() + result_dict = dict() + for mode in ["graph", "vm"]: + configs = [ + (False, False, False), + (True, False, False), + 
(True, True, False), + ] + if test_bf16 and bf16_supported(): + configs += [(True, False, True), (True, True, True)] + for use_dnnl, alter_layout, use_bf16 in configs: + result_key = ( + mode + + ("_dnnl" if use_dnnl else "") + + ("_layout" if alter_layout else "") + + ("_bf16" if use_bf16 else "_fp32") + ) + processed_mod = mod + if use_bf16: + processed_mod = relay.transform.ToMixedPrecision("bfloat16")(processed_mod) + if tvm.ir.structural_equal(processed_mod, mod): + print("can not convert to bfloat16, skipping...") + continue + if use_dnnl: + processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) + check_dnnl_used(processed_mod) + + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=processed_mod, device=dev, target=target + ).evaluate() + if run_module: + if isinstance(input, dict): + result_dict[result_key] = func(**input, **params) + else: + result_dict[result_key] = func(input, **params) + + if run_module: + assert_result_dict_holds(result_dict) + + +def run_and_verify_func( + config, run_module, subgraph_num=None, target="llvm", dtype="float32", test_bf16=True +): + """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs. + Parameters + ---------- + config : Tuple[relay.Function, Dict[str, NDArray], List[str]] + A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and + 3) A list of which vars should be considered params. + run_module: bool + If True, the built module will be run after being compiled. + """ + f, input_shapes, is_param = config + params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype) for x in is_param} + input_dict = { + k: np.random.uniform(-1, 1, v).astype(dtype) + for k, v in input_shapes.items() + if k not in is_param + } + run_and_verify( + f, + input_dict, + params, + subgraph_num=subgraph_num, + target=target, + run_module=run_module, + test_bf16=test_bf16, + ) + + +def get_conv1d( + x_shape=((1, 3, 224)), + k_shape=(16, 3, 3), + groups=1, + padding=(1, 1), + strides=(1), + dilation=(1), + channels=None, + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv1d( + x, + kernel, + kernel_size=k_shape[2:3], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=k_shape[0], + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"): + conv, dic, param_lst = get_conv1d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"): + conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, dtype=dtype) + beta = 
relay.const(np.zeros(k_shape[0]).astype(dtype)) + gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) + moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) + moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) + conv1d_bias_bn, _, _ = relay.nn.batch_norm( + conv1d_bias, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=1, + center=True, + scale=True, + epsilon=1e-5, + ) + return relay.nn.relu(conv1d_bias_bn), dic, param_lst + + +def get_conv2d( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv2d( + x, + kernel, + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=k_shape[0], + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_transpose( + x_shape=(1, 32, 8, 8), + k_shape=(32, 16, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv2d_transpose( + x, + kernel, + channels=k_shape[1] * groups, + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_weights_const( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + dic = {"x": x_shape} + param_lst = [] + return out, dic, param_lst + + +def get_conv2d_bias( + x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv2d_weights_const(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_transpose_bias( + x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv2d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[1],) + 
param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): + conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) + beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) + gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) + moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) + moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) + conv2d_bias_bn, _, _ = relay.nn.batch_norm( + conv2d_bias, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=1, + center=True, + scale=True, + epsilon=1e-5, + ) + return relay.nn.relu(conv2d_bias_bn), dic, param_lst + + +def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): + conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) + sum_data = relay.const(np.random.randint(x_shape).astype(dtype)) + conv2d_bias_sum = relay.add(sum_data, conv2d_bias) + return relay.nn.relu(conv2d_bias_sum), dic, param_lst + + +def get_conv3d( + x_shape=(1, 32, 8, 8, 8), + k_shape=(16, 32, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + dilation=(1, 1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) + out = relay.nn.conv3d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_transpose( + x_shape=(1, 32, 8, 8, 8), + k_shape=(32, 16, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + output_padding=(0, 0, 0), + activation=None, + dtype="float32", + data_layout="NCDHW", + kernel_layout="OIDHW", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) + out = relay.nn.conv3d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:5], + groups=groups, + padding=padding, + strides=strides, + output_padding=output_padding, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_bias( + x_shape=(1, 32, 8, 8, 8), k_shape=(16, 32, 3, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv3d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), 
dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_transpose_bias( + x_shape=(1, 32, 8, 8, 8), k_shape=(32, 16, 3, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv3d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[1],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.dense(x, kernel, units=k_shape[0]) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + return out, dic, param_lst + + +def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): + dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(dense, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + return out, dic, param_lst + + +def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"): + xshape = (1, 32, 14, 14) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.add(x, x) + z = relay.cast(relay.cast(y, "int32"), "float32") + out = relay.nn.relu(z) + f = relay.Function([x], out) + mod = tvm.IRModule() + mod["main"] = f + mod = partition_for_dnnl(mod) + for mode in ["graph", "vm"]: + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate() + if run_module: + results = func(x_data) + + +def test_multiple_outputs(run_module, dtype="float32"): + def get_graph(): + x = relay.var("x", shape=(1, 3), dtype=dtype) + y = relay.var("y", shape=(1, 3), dtype=dtype) + z = relay.add(x, y) + w = relay.add(z, y) + out = relay.Tuple((z, w)) + f = tvm.IRModule.from_expr(out) + return f, {"x": (1, 3), "y": (1, 3)}, [] + + run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype) + + +def test_elementwise(run_module, dtype="float32"): + def get_graph(op, x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = op(x) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + for op in [ + relay.abs, + relay.exp, + relay.log, + relay.sqrt, + relay.nn.relu, + relay.tanh, + relay.sigmoid, + ]: + run_and_verify_func(get_graph(op), run_module=run_module) + + +def test_clip(run_module, dtype="float32"): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.clip(x, a_min=-0.2, a_max=0.4) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(), run_module=run_module) + + +def test_leaky_relu(run_module, dtype="float32"): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.nn.leaky_relu(x, alpha=0.1) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(), run_module=run_module) + + +def 
test_softmax(run_module, dtype="float32"): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.nn.softmax(x, axis=axis) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1000), axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module) + + +def test_conv1d(run_module, dtype="float32"): + conv1d, dic, param_lst = get_conv1d(channels=16, dtype=dtype) + conv1d = tvm.IRModule.from_expr(conv1d) + config = conv1d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + x_shape = (1, 32, 224) + k_shape = (16, 32, 3) + conv1d_bias, dic, param_lst = get_conv1d(x_shape, k_shape, dtype=dtype) + conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) + config = conv1d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv1d_pattern(run_module, dtype="float32"): + x_shape = (1, 3, 224) + k_shape = (16, 3, 3) + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv1d, dic, param_lst = get_conv1d(x_shape, k_shape, activation=a, dtype=dtype) + conv1d = tvm.IRModule.from_expr(conv1d) + config = conv1d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, activation=a, dtype=dtype) + conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) + config = conv1d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d(run_module, dtype="float32"): + x_shape = (1, 32, 8, 8) + for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 2, 3, 3), 16)]: + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + for dilation in [(1, 1), (2, 2)]: + conv2d, dic, param_lst = get_conv2d( + x_shape=x_shape, + k_shape=k_shape, + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + dtype=dtype, + ) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d_weights_const(run_module, dtype="float32"): + x_shape = (1, 32, 8, 8) + k_shape = (16, 32, 3, 3) + conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + x_shape = (1, 3, 8, 8) + k_shape = (16, 3, 3, 3) + conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d_pattern(run_module, dtype="float32"): + x_shape = (1, 32, 8, 8) + k_shape = (16, 32, 3, 3) + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=a, dtype=dtype) + conv2d_bias = tvm.IRModule.from_expr(conv2d_bias) + config = conv2d_bias, dic, param_lst + 
run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype) + conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu) + config = conv2d_bias_bn_relu, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype) + conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu) + config = conv2d_bias_bn_relu, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d_transpose(run_module, dtype="float32"): + x_shape = (1, 32, 8, 8) + for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]: + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + conv2d_transpose, dic, param_lst = get_conv2d_transpose( + x_shape=x_shape, + k_shape=k_shape, + groups=groups, + padding=padding, + strides=strides, + dtype=dtype, + ) + conv2d_transpose = tvm.IRModule.from_expr(conv2d_transpose) + config = conv2d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d_transpose_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bias, dic, param_lst = get_conv2d_transpose_bias(activation=a, dtype=dtype) + conv2d_bias = tvm.IRModule.from_expr(conv2d_bias) + config = conv2d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d(run_module, dtype="float32"): + conv3d, dic, param_lst = get_conv3d(dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d, dic, param_lst = get_conv3d(padding=(0, 0, 0, 1, 1, 1), dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d, dic, param_lst = get_conv3d( + x_shape=(1, 3, 8, 8, 8), k_shape=(16, 3, 3, 3, 3), dtype=dtype + ) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_bias, dic, param_lst = get_conv3d_bias(activation=a, dtype=dtype) + conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) + config = conv3d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_transpose(run_module, dtype="float32"): + conv3d_transpose, dic, param_lst = get_conv3d_transpose(dtype=dtype) + conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_transpose, dic, param_lst = get_conv3d_transpose(strides=(2, 2, 2), dtype=dtype) + conv3d_transpose = 
tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_transpose, dic, param_lst = get_conv3d_transpose( + strides=(2, 2, 2), output_padding=(1, 1, 1), dtype=dtype + ) + conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_transpose_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_bias, dic, param_lst = get_conv3d_transpose_bias(activation=a, dtype=dtype) + conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) + config = conv3d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_dense(run_module, dtype="float32"): + x_shape = (1, 16) + k_shape = (32, 16) + + dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype) + dense = tvm.IRModule.from_expr(dense) + config = dense, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + dense, dic, param_lst = get_dense(x_shape, k_shape=(1, 16), dtype=dtype) + dense = tvm.IRModule.from_expr(dense) + config = dense, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_dense_pattern(run_module, dtype="float32"): + x_shape = (1, 16) + k_shape = (32, 16) + + dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype) + dense = tvm.IRModule.from_expr(dense) + config = dense, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, dtype=dtype) + dense_bias = tvm.IRModule.from_expr(dense_bias) + config = dense_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_pool2d(run_module, dtype="float32"): + def get_graph( + op, + x_shape=(1, 3, 32, 32), + pool_size=(2, 2), + strides=(2, 2), + padding=(0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype=dtype) + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + out = tvm.IRModule.from_expr(out) + return out, {"x": x_shape}, [] + + for pool_size in [(2, 2), (3, 3)]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]: + for ceil_mode in [False]: + # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling" + if pool_size == (2, 2) and padding == (0, 0, 1, 1): + continue + for count_include_pad in [False, True]: + # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding" + if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)): + continue + run_and_verify_func( + get_graph( + relay.nn.avg_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ), + run_module=run_module, + ) + run_and_verify_func( + get_graph( + 
relay.nn.max_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ), + run_module=run_module, + ) + + +def test_pool3d(run_module, dtype="float32"): + def get_graph( + op, + x_shape=(1, 3, 8, 32, 32), + pool_size=(2, 2, 2), + strides=(2, 2, 2), + padding=(0, 0, 0), + ceil_mode=False, + count_include_pad=None, + dtype="float32", + ): + x = relay.var("x", shape=(x_shape), dtype=dtype) + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + out = tvm.IRModule.from_expr(out) + return out, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.avg_pool3d), run_module=run_module) + run_and_verify_func(get_graph(relay.nn.max_pool3d), run_module=run_module) + run_and_verify_func( + get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)), run_module=run_module + ) + run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module) + + +def test_prune_dnnl_subgraph(run_module): + """In this test, OP "add" should be offloaded from dnnl codegen.""" + + def get_graph(): + x1 = relay.var("x1", shape=(1, 32, 56, 56)) + x2 = relay.var("x2", shape=(1, 32, 56, 56)) + bias = relay.var("bias", shape=(32,)) + weight = relay.var("weight", shape=(32, 32, 3, 3)) + y = relay.nn.conv2d( + x1, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + ) + y = relay.nn.bias_add(y, bias) + y = relay.nn.relu(y) + y = relay.nn.global_max_pool2d(y) + y = relay.add(y, x2) + dic = { + "x1": (1, 32, 56, 56), + "x2": (1, 32, 56, 56), + "weight": (32, 32, 3, 3), + "bias": (32,), + } + param_lst = ["weight", "bias"] + out = tvm.IRModule.from_expr(y) + return out, dic, param_lst + + run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False) + + +if __name__ == "__main__": + tvm.testing.main() From 8ba43003a00c2ca92017df2ec24ccaef6ddcf636 Mon Sep 17 00:00:00 2001 From: Jian Sheng <84881952+jsheng-jian@users.noreply.github.com> Date: Tue, 7 Jun 2022 23:09:49 -0700 Subject: [PATCH 066/181] minor fix after loading trt engine from disk (#11614) --- src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 554515c456797..18ffdbbbba858 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -376,7 +376,8 @@ class TensorRTRuntime : public JSONRuntimeBase { helper.DeclareField("batch_size", &batch_size); helper.ReadAllFields(&reader); trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context; - LOG(INFO) << "finished saving engine and context ... "; + max_batch_size_ = batch_size; + LOG(INFO) << "finished loading engine and context ... "; return true; } From 6dc0c624cdd8fb9d7fdd2194a755b0dffbe2de93 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Wed, 8 Jun 2022 03:00:35 -0700 Subject: [PATCH 067/181] [Relay] Restore dominator check (#11616) It is ok to match a sub-graph which has dataflow outside of the sub-graph, provided all such flows eventually come into the sub-graph. 
--- src/relay/ir/dataflow_matcher.cc | 28 +++++++---- src/relay/ir/dataflow_matcher_impl.h | 1 + tests/python/contrib/test_cutlass.py | 2 +- tests/python/relay/test_dataflow_pattern.py | 52 ++++++++++++++++++++- 4 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index df896cb690eb2..b2776a41c50ce 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -609,6 +609,8 @@ void PatternGrouper::VisitExprs() { } void PatternGrouper::CreateGroup(const Expr& expr) { + VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr); + int var_number = 0; auto node_map = matcher_->GetMemo(); @@ -696,6 +698,7 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto body = extractor.Mutate(expr); group.function = Function(params, body, NullValue(), Array()); + VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function); group.name = extractor.GetName(); // Check to make sure we aren't overlapping with another group or creating an invalid fusion // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the @@ -708,6 +711,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // Similiarly, if interior nodes in a group are used outside of the group fusing to a single // output would create an invalid graph tranformation, so we block the creation of such groups. auto memo = extractor.GetMemo(); + for (auto kv : memo) { + VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_; + } + for (auto kv : memo) { // Check to ensure that this node isn't an input or a global if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && @@ -720,16 +727,19 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // if the node isn't the output of the group auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { - // and the node is used by nodes outside of the group if (memo.count(output->ref()) == 0) { - // TODO(mbs): This condition used to also include the following test, which since - // the dominators relation is used back-to-front was always vacuously true. So the - // code is just rejecting the match if a strictly internal node happened to connect - // to an outside node. - ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); - // Exit because nodes in this pattern's body are used outside the pattern, fusing it - // would be invalid - return; + // A node inside the matched group contributes an output to nodes outside of the matched + // group... + auto root = matcher_->expr_to_node(expr); + if (!root->Dominates(output)) { + // ...and the outside dataflow does not come back to the root of the matched group. + // So reject the match since it would create a cycle. + VLOG(1) << "Rejecting group since would create a cycle with output " << output->index_ + << " for root " << root->index_ << " in graph:" << std::endl + << matcher_->expr_graph().ToString(); + return; + } + // else: We'll allow the output to be included in the matched group. 
} } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index f04190f72e40b..a174d8e34eb7f 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -55,6 +55,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual>& memo() const { return memo_; } + const IndexedGraph& expr_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index c105979402211..8e5238b17399c 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -941,4 +941,4 @@ def test_conv2d_bwd(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f0474c9112736..ba066e9a438f9 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1458,7 +1458,6 @@ def concat(*args): def test_partition_fuzzy_function_args(): - func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard() x = relay.var("x") y = relay.var("y") @@ -1790,5 +1789,56 @@ def callback(self, pre, post, node_map): assert tvm.ir.structural_equal(out, expected) +def test_matched_outside_but_dominated(): + """In this example the pattern matches the nn.conv2d/add/multiply flow. Even though the + add output is consumed by the sigmoid, the sigmoid itself is dominated by the multiply. + So partitioning can proceed, all be it with a duplication of the add.""" + in_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] { + %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC"); + %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI"); + %2 = expand_dims(%bias, axis=1, num_newaxis=2); + %3 = expand_dims(%2, axis=0); + %4 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %5 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC"); + %6 = add(%4, %5); + %7 = sigmoid(%6); + %8 = multiply(%6, %7); + layout_transform(%8, src_layout="NHWC", dst_layout="NCHW") + } + """ + ) + expected_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] { + %2 = expand_dims(%bias, axis=1, num_newaxis=2); + %3 = expand_dims(%2, axis=0); + %4 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC"); + %5 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI"); + %6 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %7 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC"); + %8 = add(%6, %7); + %9 = sigmoid(%8); + %10 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, %FunctionVar_0_3, PartitionedFromPattern="nn.conv2d_add_multiply_") { + %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %1 = add(%0, %FunctionVar_0_2); + multiply(%1, 
%FunctionVar_0_3) + }; + %11 = %10(%4, %5, %7, %9); + layout_transform(%11, src_layout="NHWC", dst_layout="NCHW") + } + """ + ) + pattern = is_op("multiply")( + is_op("add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard()), wildcard() + ) + actual_mod = tvm.IRModule.from_expr(pattern.partition(in_mod["main"])) + actual_mod = relay.transform.InferType()(actual_mod) + tvm.ir.assert_structural_equal(actual_mod, expected_mod) + + if __name__ == "__main__": tvm.testing.main() From b00b1229c881fa6f2f9fe9e44819c9dc3de09f74 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 8 Jun 2022 07:24:36 -0500 Subject: [PATCH 068/181] [Hexagon] Make local symbols visible to loaded modules in RPC server (#11611) The simulator library `libhexagon_rpc_sim.so` contains TVM runtime built into it, but since it's loaded as a "local" library these symbols are not visible to shared libraries loaded by subsequent dlopens. (Same applies to symbols from the C++ runtime.) To make these symbols visible, dlopen the defining libraries as "global". (Re-dlopeninig an already loaded library is a well-defined operation.) --- src/runtime/hexagon/rpc/simulator/rpc_server.cc | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index 29373be542f3f..9b4ce3f11443e 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include @@ -288,7 +289,16 @@ int DISPATCH_FUNCTION_NAME(void* serverp) { return 0; } -int main() { +int main(int argc, char* argv[]) { + // Load C++RT and ourselves as "global" to make all the symbols defined + // there be visible to any subsequent libraries loaded via dlopen. + void* cxx_abi = dlopen("libc++abi.so", RTLD_GLOBAL); + ICHECK(cxx_abi != nullptr); + void* cxx = dlopen("libc++.so", RTLD_GLOBAL); + ICHECK(cxx != nullptr); + void* self = dlopen(argv[0], RTLD_GLOBAL); + ICHECK(self != nullptr); + const auto* api = tvm::runtime::Registry::Get("device_api.hexagon"); ICHECK(api != nullptr); tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api); @@ -308,6 +318,9 @@ int main() { // nothing } + dlclose(self); + dlclose(cxx); + dlclose(cxx_abi); return 0; } From e19cf20054a9fe5049c71b02753c155110b0a6ba Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Wed, 8 Jun 2022 15:21:29 +0200 Subject: [PATCH 069/181] TVMC: Allow to overwrite TVM_CONFIGS_JSON_DIR via environment variables (#11623) If a non-default location for the build directory is used, e.g. set via TVM_LIBRARY_PATH we need to provide the user a way to overwrite CONFIGS_JSON_DIR as well. --- python/tvm/driver/tvmc/config_options.py | 9 +++++++ .../driver/tvmc/test_parse_config_file.py | 27 ++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/config_options.py b/python/tvm/driver/tvmc/config_options.py index ae5616e7245af..c384c89b1a2b6 100644 --- a/python/tvm/driver/tvmc/config_options.py +++ b/python/tvm/driver/tvmc/config_options.py @@ -43,6 +43,15 @@ def get_configs_json_dir() -> str: """ global CONFIGS_JSON_DIR if CONFIGS_JSON_DIR is None: + + # If a non-default location for the build directory is used, e.g. set via TVM_LIBRARY_PATH + # we need to provide the user a way to overwrite CONFIGS_JSON_DIR as well. 
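        # (Editorial sketch, not part of the original hunk: with this change a
        #  user who builds TVM in a non-default location can run, for example,
        #      TVM_LIBRARY_PATH=$HOME/tvm/build \
        #      TVM_CONFIGS_JSON_DIR=$HOME/tvm/configs/host \
        #      tvmc compile --target "llvm" model.onnx
        #  where the two paths are placeholders; as the code below shows, the
        #  override is only honored when the directory actually exists.)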
+ if os.environ.get("TVM_CONFIGS_JSON_DIR", None): + user_config_dir = os.environ["TVM_CONFIGS_JSON_DIR"] + if os.path.isdir(user_config_dir): + CONFIGS_JSON_DIR = user_config_dir + return CONFIGS_JSON_DIR + candidate_paths = [] candidate_paths.extend(libinfo.find_lib_path()) # When running from source, the configs directory will be located one directory above the diff --git a/tests/python/driver/tvmc/test_parse_config_file.py b/tests/python/driver/tvmc/test_parse_config_file.py index a80daba3a47ab..6aec2cd453a3e 100644 --- a/tests/python/driver/tvmc/test_parse_config_file.py +++ b/tests/python/driver/tvmc/test_parse_config_file.py @@ -20,7 +20,7 @@ import tvm from tvm.driver.tvmc.main import _main -from tvm.driver.tvmc.config_options import convert_config_json_to_cli +from tvm.driver.tvmc.config_options import convert_config_json_to_cli, get_configs_json_dir def test_parse_json_config_file_one_target(): @@ -153,3 +153,28 @@ def test_tvmc_cl_compile_run_config_file(tflite_mobilenet_v1_1_quant, tmpdir_fac exit_code = _main(tvmc_args) on_error = "Trying to run a MLF archive must fail because it's only supported on micro targets." assert exit_code != 0, on_error + + +def test_tvmc_get_configs_json_dir(tmpdir_factory, monkeypatch): + # Reset global state + monkeypatch.setattr(tvm.driver.tvmc.config_options, "CONFIGS_JSON_DIR", None) + + # Get default directory for reference + default_dir = get_configs_json_dir() + + # Set custom dir which does not exist -> ignore + monkeypatch.setattr(tvm.driver.tvmc.config_options, "CONFIGS_JSON_DIR", None) + monkeypatch.setenv("TVM_CONFIGS_JSON_DIR", "not_a_directory") + result = get_configs_json_dir() + assert_msg = "Non-existant directory passed via TVM_CONFIGS_JSON_DIR should be ignored." + assert result == default_dir, assert_msg + + # Set custom dir which does exist + monkeypatch.setattr(tvm.driver.tvmc.config_options, "CONFIGS_JSON_DIR", None) + configs_dir = tmpdir_factory.mktemp("configs") + monkeypatch.setenv("TVM_CONFIGS_JSON_DIR", str(configs_dir)) + result = get_configs_json_dir() + assert_msg = ( + "Custom value passed via TVM_CONFIGS_JSON_DIR should be used instead of default one." + ) + assert result != default_dir and result is not None, assert_msg From 96a513cd97be4b42acb51d1c9b73288820e90185 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 8 Jun 2022 11:39:42 -0700 Subject: [PATCH 070/181] Patch replay trace. (#11621) --- include/tvm/meta_schedule/search_strategy.h | 4 +++- .../search_strategy/replay_trace.py | 8 +++++++- .../search_strategy/replay_trace.cc | 18 +++++++++++++++--- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index baae22f0d98ec..5e249850f5d5b 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -211,8 +211,10 @@ class SearchStrategy : public runtime::ObjectRef { * \brief Constructor of replay trace search strategy. * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. * \param max_trials_per_task The total number of trials for trace replaying. + * \param max_fail_count The max number of failures during trace replaying. */ - TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task); + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task, + int max_fail_count); /*! * \brief Constructor of replay func search strategy. 
diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 70461d65f7765..36dbb8734e577 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -33,15 +33,21 @@ class ReplayTrace(SearchStrategy): Number of trials per iteration. max_trials_per_task : int Total number of trials for one task + max_fail_count : int + Max number of failures during trace replaying. """ num_trials_per_iter: int max_trials_per_task: int + max_fail_count: int - def __init__(self, num_trials_per_iter: int, max_trials_per_task: int): + def __init__( + self, num_trials_per_iter: int, max_trials_per_task: int, max_fail_count: int = 100 + ): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, max_trials_per_task, + max_fail_count, ) diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 13f32a744e3a0..355f71455d912 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -60,6 +60,8 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_per_iter; /*! \brief The number of total trials. */ int max_trials_per_task; + /*! \brief The max number of failures during trace replaying. */ + int max_fail_count; /*! \brief The tuning context of the search strategy. */ const TuneContextNode* context_{nullptr}; @@ -71,6 +73,7 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("max_trials_per_task", &max_trials_per_task); + v->Visit("max_fail_count", &max_fail_count); // `context_` is not visited. 
// `rand_state_` is not visited // `state_` is not visited @@ -136,7 +139,8 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; IRModule mod = this->per_thread_mod_[thread_id]; - for (;;) { + + for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) { int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; tir::Trace new_trace = tir::Trace(trace->insts, {}); @@ -147,7 +151,13 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure } }; support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker); - return per_task_result; + Array filtered; + filtered.reserve(ed - st); + for (MeasureCandidate result : per_task_result) + if (result.defined()) { + filtered.push_back(result); + } + return filtered; } inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { @@ -155,10 +165,12 @@ inline void ReplayTraceNode::State::NotifyRunnerResults(const Arraynum_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task) { +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task, + int max_fail_count) { ObjectPtr n = make_object(); n->num_trials_per_iter = num_trials_per_iter; n->max_trials_per_task = max_trials_per_task; + n->max_fail_count = max_fail_count; return SearchStrategy(n); } From 9817338508f3f8cd5a444133b4de99ce577c031b Mon Sep 17 00:00:00 2001 From: billishyahao Date: Thu, 9 Jun 2022 03:12:36 +0800 Subject: [PATCH 071/181] [BYOC][DNNL] Enable layer normalization in DNNL byoc. (#11508) * Enable layer normalization in DNNL byoc. * Added unittest for layer norm and make code compatible after introducing TensorRequisite(PR-11345) * Fix lint issue * Fix clang format issue --- python/tvm/relay/op/contrib/dnnl.py | 70 ++++++++++++++++++- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 47 +++++++++++++ tests/python/contrib/test_dnnl.py | 21 ++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 2e975cf49c885..c87a7162b0707 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -41,7 +41,7 @@ from tvm.relay.expr_functor import ExprMutator, ExprVisitor from ... import _ffi_api -from ...dataflow_pattern import wildcard, is_op +from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table logger = logging.getLogger("DNNL") @@ -92,6 +92,7 @@ def _func_wrapper(expr): _register_external_op_helper("nn.softmax") _register_external_op_helper("add") _register_external_op_helper("multiply") +_register_external_op_helper("nn.layer_norm") def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None): @@ -455,6 +456,7 @@ def visit_call(self, call): "nn.conv3d", "nn.conv3d_transpose", "nn.dense", + "nn.layer_norm", ] ) if isinstance(call.op, tvm.tir.op.Op): @@ -526,3 +528,69 @@ def visit_call(self, call): new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"]) new_mod = transform.RemoveUnusedFunctions()(new_mod) return new_mod + + +class LayerNormRewrite(DFPatternCallback): + """ + A callback to rewrite the following operators into a single layer normalization operator. 
+ + Pattern #1: + 1 %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */; + 2 %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */; + 3 %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */; + 4 %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */; + 5 %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */; + 6 %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */; + 7 %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */; + 8 %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */; + 9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) + /* ty=Tensor[(1, 3136, 64), float32] */; + 10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) + /* ty=Tensor[(1, 3136, 64), float32] */; + + Pattern #2: + 1 %0 = mean(%input, axis=[-1], keepdims=True); + 2 %1 = variance(%input, %0, axis=[-1], keepdims=True); + 3 %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */; + 4 %3 = subtract(%input, %0); + 5 %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */; + 6 %5 = divide(%3, %4); + 7 %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */) + /* ty=Tensor[(1, 49, 64), float32] */; + 8 %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */) + /* ty=Tensor[(1, 49, 64), float32] */ + + """ + + def __init__(self): + super(LayerNormRewrite, self).__init__() + self.data = wildcard() + self.gamma = wildcard() + self.beta = wildcard() + mu = is_op("mean")(self.data) + diff = is_op("subtract")(self.data, mu) + cdiff = diff | is_op("cast")(diff) + const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0)) + p1 = is_op("power")(cdiff, const_two) + mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu) + eps = is_expr(relay.const(1e-5)) + added_eps = is_op("add")(mp1, eps) + deno = is_op("sqrt")(added_eps) + div_out = is_op("divide")(diff, deno) + weighted = is_op("multiply")(div_out, self.gamma) + added_bias = is_op("add")(weighted, self.beta) + self.pattern = added_bias + + def callback(self, pre, post, node_map): + data = node_map[self.data][0] + gamma = node_map[self.gamma][0] + beta = node_map[self.beta][0] + return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta) + + +def rewrite_layer_norm(mod): + """Rewrite the input graph to replace multiple operators with a TVM native layer normalization + operator so that we can offload them to dnnl layer normalization byoc part. 
+ """ + mod["main"] = rewrite(LayerNormRewrite(), mod["main"]) + return mod diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index a2417f012ea42..db8f25e2a6ea5 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Binary(nid, dnnl::algorithm::binary_add); } else if ("multiply" == op_name) { Binary(nid, dnnl::algorithm::binary_mul); + } else if ("nn.layer_norm" == op_name) { + LayerNorm(nid); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -449,6 +451,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {DNNL_ARG_VARIANCE, var_tr}}); } + void LayerNorm(const size_t& nid) { + auto node = nodes_[nid]; + + auto src_tr = GetInput(nid, 0); + auto gamma_tr = GetInput(nid, 1); + auto beta_tr = GetInput(nid, 2); + auto dst_tr = GetOutput(nid, 0); + + auto axis = GetNodeAttr(node, "axis"); + auto epsilon = GetNodeAttr(node, "epsilon"); + auto center = GetNodeAttr(node, "center"); + auto scale = GetNodeAttr(node, "scale"); + + ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case"; + + // LN description. + auto lnorm_desc = dnnl::layer_normalization_forward::desc( + dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon, + dnnl::normalization_flags::use_scale_shift); + + auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_); + + // Concatenate scale and shift tensors + auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid()); + auto sc_sh_dims = scale_shift_tr.dims(); + + ICHECK(sc_sh_dims.size() == 2); + ICHECK(sc_sh_dims[0] == 2); + sc_sh_dims[0] /= 2; + auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze(); + auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze(); + + auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) { + dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc()); + Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); + }; + + register_copy(gamma_tr, scale_tr); + register_copy(beta_tr, shift_tr); + + Submit( + dnnl::layer_normalization_forward(lnorm_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}}); + } + void Pooling(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index babfad4a0c8c7..3e4e831aa594e 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): with tvm.transform.PassContext(opt_level=3): mod = alter_layout_seq(mod) + mod = dnnl.rewrite_layer_norm(mod) + byoc_seq = tvm.transform.Sequential( [ transform.MergeComposite(dnnl.pattern_table()), @@ -454,6 +456,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype return relay.nn.relu(conv2d_bias_bn), dic, param_lst +def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"): + dic = {"input": x_shape} + param_lst = [] + input = relay.var("input", shape=x_shape) + beta = relay.const(np.zeros(x_shape[2]).astype(dtype)) + gamma = relay.const(np.ones(x_shape[2]).astype(dtype)) + out = relay.nn.layer_norm(input, gamma=gamma, beta=beta) + return out, dic, param_lst + + def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), 
dtype="float32"): conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) sum_data = relay.const(np.random.randint(x_shape).astype(dtype)) @@ -1032,5 +1044,14 @@ def get_graph(): run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False) +def test_layer_norm(run_module, dtype="float32"): + x_shape = (1, 49, 64) + + ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype) + ln = tvm.IRModule.from_expr(ln) + config = ln, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + if __name__ == "__main__": tvm.testing.main() From 99c113a237cfd3f21d78fbb405160ed8b9b5af0b Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 8 Jun 2022 12:39:09 -0700 Subject: [PATCH 072/181] [COMMUNITY] @tkonolige -> Committer (#11626) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index cfd99ae73f653..8f43ad455e08a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -47,6 +47,7 @@ We do encourage everyone to work anything they are interested in. - [Ziheng Jiang](https://github.com/ZihengJiang) (PMC): @ZihengJiang - relay, compiler - [Manupa Karunaratne](https://github.com/manupa-arm): @manupa-arm - ethos-u, memory planner - [Marisa Kirisame](https://github.com/MarisaKirisame): @MarisaKirisame - relay +- [Tristan Konolige](https://github.com/tkonolige): @tkonolige - profiling, relay, tir, runtime - [Ruihang Lai](https://github.com/MasterJH5574): @MasterJH5574 - tir, tvm-script - [Wuwei Lin](https://github.com/vinx13): @vinx13 - relay, topi - [Yizhi Liu](https://github.com/yzhliu) (PMC): @yzhliu - jvm, topi, relay From 97e681dc3477570b268bd84aae539219e5a0b29c Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 8 Jun 2022 13:23:58 -0700 Subject: [PATCH 073/181] [Hexagon] Add random string to workspace name (#11593) --- python/tvm/contrib/hexagon/build.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 43856253cb180..c659d66bec5db 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -25,6 +25,8 @@ import signal import socket import stat +import random +import string import subprocess from typing import Union @@ -58,7 +60,9 @@ def _get_hexagon_rpc_lib_dir() -> pathlib.Path: def _get_test_directory_name() -> str: """Generate a time-stamped name for use as a test directory name.""" - return datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + random_str = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) + return f"{date_str}-{random_str}" class HexagonLauncherRPC(metaclass=abc.ABCMeta): From df4f4c0b4bccd775af25967fdf057392c1a2826e Mon Sep 17 00:00:00 2001 From: "Sevin F. 
Varoglu" Date: Wed, 8 Jun 2022 14:08:06 -0700 Subject: [PATCH 074/181] [ONNX] Add ReduceSum opset13 support (non-dynamic) (#11606) * [ONNX] Add ReduceSum opset13 support (non-dynamic) * Add check * Add support for constant axis * noop * Rework logic --- python/tvm/relay/frontend/onnx.py | 26 ++++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 4 ---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index abfa5629d5534..29c0a778ef6ee 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2270,6 +2270,32 @@ def _impl_v12(cls, inputs, attr, params): return cls._impl_v1(inputs, attr, params) + @classmethod + def _impl_v13(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + + noop_with_empty_axes = attr.get("noop_with_empty_axes", 0) + num_axis = int(infer_type(inputs[1]).checked_type.shape[0]) if inputs[1] is not None else 0 + + if noop_with_empty_axes and num_axis == 0: + return inputs[0] + + if len(inputs) == 2: + if isinstance(inputs[1], _expr.Constant): + # Get axis and unpack scalar + constant_axis = int(inputs[1].data.numpy()[0]) + return cls.run_calculation([inputs[0]], constant_axis, attr.get("keepdims", True)) + + if num_axis > 0: + raise ValueError("Dynamic Reduce is not supported yet!") + + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + return cls.run_calculation([inputs[0]], axis, attr.get("keepdims", True)) + + return cls._impl_v1(inputs, attr, params) + class ReduceMax(Reduce): """Operator converter for ReduceMax.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index ebaad9b4cb136..967597f7d12b8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5172,12 +5172,8 @@ def verify_eyelike(indata): "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_negative_delta_expanded", - "test_reduce_sum_default_axes_keepdims_example", - "test_reduce_sum_default_axes_keepdims_random", "test_reduce_sum_do_not_keepdims_example", "test_reduce_sum_do_not_keepdims_random", - "test_reduce_sum_empty_axes_input_noop_example", - "test_reduce_sum_empty_axes_input_noop_random", "test_reduce_sum_keepdims_example", "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", From 2f9d9b4e5c7dcb3c9879fb2496f1f50e85b9c55a Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 9 Jun 2022 07:31:55 +0300 Subject: [PATCH 075/181] [OpenCL] Implement conv2d_winograd algorithm for Adreno (#11543) * Implement conv2d_winograd algorithm for Adreno * Implement gtest for OpenCL texture pool * Implement conv2d_nhwc_winograd for Adreno * Minor refactoring * Fix lint * Apply comments * Apply comments * Fix lint --- CMakeLists.txt | 16 + cmake/modules/LibInfo.cmake | 1 + cmake/modules/OpenCL.cmake | 6 + python/tvm/relay/op/strategy/adreno.py | 99 +++- python/tvm/topi/adreno/__init__.py | 2 + python/tvm/topi/adreno/conv2d_alter_op.py | 218 +++++++- .../tvm/topi/adreno/conv2d_nchw_winograd.py | 128 +++++ .../tvm/topi/adreno/conv2d_nhwc_winograd.py | 128 +++++ .../tvm/topi/adreno/conv2d_winograd_common.py | 512 ++++++++++++++++++ python/tvm/topi/adreno/utils.py | 28 + src/runtime/opencl/texture_pool.cc | 191 ++++--- src/runtime/texture.h | 22 +- src/support/libinfo.cc | 5 + 
.../opencl/opencl_texture_pool_test.cc | 151 ++++++ tests/cpp-runtime/opencl/run_gtests.cc | 60 ++ tests/python/contrib/test_opencl/conftest.py | 29 + .../contrib/test_opencl/test_run_gtests.py | 55 ++ .../python/relay/test_conv2d_nchw_texture.py | 43 ++ .../python/relay/test_conv2d_nhwc_texture.py | 43 ++ tests/python/relay/utils/adreno_utils.py | 1 + 20 files changed, 1638 insertions(+), 100 deletions(-) create mode 100644 python/tvm/topi/adreno/conv2d_nchw_winograd.py create mode 100644 python/tvm/topi/adreno/conv2d_nhwc_winograd.py create mode 100644 python/tvm/topi/adreno/conv2d_winograd_common.py create mode 100644 tests/cpp-runtime/opencl/opencl_texture_pool_test.cc create mode 100644 tests/cpp-runtime/opencl/run_gtests.cc create mode 100644 tests/python/contrib/test_opencl/conftest.py create mode 100644 tests/python/contrib/test_opencl/test_run_gtests.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5352eddd25987..b4d6e18aad630 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,7 @@ endif() # Alernatively, use cmake -DOPTION=VALUE through command-line. tvm_option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF) +tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest) tvm_option(USE_VULKAN "Build with Vulkan" OFF) @@ -609,6 +610,18 @@ if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTES include_directories("${USE_HEXAGON_GTEST}/include") endif() +if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + include(FetchContent) + FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + target_link_libraries(tvm_runtime PUBLIC gtest) + target_link_libraries(tvm PUBLIC gtest) + include_directories("${USE_OPENCL_GTEST}/include") + include_directories("${USE_OPENCL_GTEST}/googletest/include") + message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") +endif() + # Set flags for clang include(cmake/modules/ClangFlags.cmake) set(CRC16_INCLUDE_PATH "3rdparty/libcrc/include") @@ -668,6 +681,9 @@ install(TARGETS tvm_runtime EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_S if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) endif() +if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) +endif() if (INSTALL_DEV) install( diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 76ddbede8ac06..3e6b3c787f656 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -89,6 +89,7 @@ function(add_lib_info src_file) TVM_INFO_USE_MSVC_MT="${USE_MSVC_MT}" TVM_INFO_USE_NNPACK="${USE_NNPACK}" TVM_INFO_USE_OPENCL="${USE_OPENCL}" + TVM_INFO_USE_OPENCL_GTEST="${USE_OPENCL_GTEST}" TVM_INFO_USE_OPENMP="${USE_OPENMP}" TVM_INFO_USE_PAPI="${USE_PAPI}" TVM_INFO_USE_PROFILER="${USE_PROFILER}" diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 648e83f575d18..430af7e8722c8 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -55,6 +55,12 @@ if(USE_OPENCL) message(STATUS "Build with OpenCL support") tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES}) + + if(DEFINED USE_OPENCL_GTEST AND EXISTS 
${USE_OPENCL_GTEST}) + file_glob_append(RUNTIME_OPENCL_SRCS + "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" + ) + endif() list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) else() list(APPEND COMPILER_SRCS src/target/opt/build_opencl_off.cc) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index a783440bb38cc..01b3935a6f1bc 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -28,6 +28,7 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy = _op.OpStrategy() data, kernel = inputs dilation_h, dilation_w = attrs.get_int_tuple("dilation") + stride_h, stride_w = attrs.get_int_tuple("strides") groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout @@ -38,6 +39,28 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): if (data_layout == "NCHW" and kernel_layout == "OIHW") or ( data_layout == "NCHW4c" and kernel_layout == "OIHW4o" ): + if len(kernel.shape) == 4: + _, _, kh, kw = get_const_tuple(kernel.shape) + else: + _, _, kh, kw, _ = get_const_tuple(kernel.shape) + if ( + (2 < kh < 8 and 2 < kw < 8 and kh == kw) + and (stride_h == 1 and stride_w == 1) + and (dilation_h == 1 and dilation_w == 1) + ): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd), + name="conv2d_nchw_winograd.image2d", + plevel=25, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_acc32), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd_acc32), + name="conv2d_nchw_winograd_acc32.image2d", + plevel=30, + ) if out_type.dtype == "float16": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc), @@ -48,12 +71,34 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32), - name="conv2d_nchwc_tpack.image2d", + name="conv2d_nchwc_acc32.image2d", plevel=20, ) elif (data_layout == "NHWC" and kernel_layout == "HWIO") or ( data_layout == "NHWC4c" and kernel_layout == "HWIO4o" ): + if len(kernel.shape) == 4: + kh, kw, _, _ = get_const_tuple(kernel.shape) + else: + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if ( + (2 < kh < 8 and 2 < kw < 8 and kh == kw) + and (stride_h == 1 and stride_w == 1) + and (dilation_h == 1 and dilation_w == 1) + ): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd), + name="conv2d_nhwc_winograd.image2d", + plevel=25, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_acc32), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd_acc32), + name="conv2d_nhwc_winograd_acc32.image2d", + plevel=30, + ) if out_type.dtype == "float16": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nhwc), @@ -153,6 +198,58 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): return strategy +@conv2d_winograd_without_weight_transfrom_strategy.register("adreno") +def conv2d_winograd_without_weight_transfrom_strategy_adreno(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transfrom adreno strategy""" + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + layout = 
attrs.data_layout + assert dilation == (1, 1), "Do not support dilate now" + assert groups == 1, "Do not supoort arbitrary group number" + strategy = _op.OpStrategy() + if layout in ("NCHW", "NCHW4c"): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform + ), + name="conv2d_nchw_winograd_without_weight_transform.image2d", + plevel=35, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform_acc32), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform_acc32 + ), + name="conv2d_nchw_winograd_without_weight_transform_acc32.image2d", + plevel=40, + ) + elif layout in ("NHWC", "NHWC4c"): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform + ), + name="conv2d_nhwc_winograd_without_weight_transform.image2d", + plevel=35, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform_acc32), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform_acc32 + ), + name="conv2d_nhwc_winograd_without_weight_transform_acc32.image2d", + plevel=40, + ) + else: + raise RuntimeError( + "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + ) + return strategy + + @schedule_pool.register("adreno") def schedule_pool_adreno(attrs, outs, target): """schedule pooling ops for adreno""" diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py index 6c9b7463c1d4e..57a9013b1a2ab 100644 --- a/python/tvm/topi/adreno/__init__.py +++ b/python/tvm/topi/adreno/__init__.py @@ -23,3 +23,5 @@ from .depthwise_conv2d_nhwc import * from .pooling import * from .conv2d_alter_op import * +from .conv2d_nchw_winograd import * +from .conv2d_nhwc_winograd import * diff --git a/python/tvm/topi/adreno/conv2d_alter_op.py b/python/tvm/topi/adreno/conv2d_alter_op.py index e8944093c0f54..16573991e09c5 100644 --- a/python/tvm/topi/adreno/conv2d_alter_op.py +++ b/python/tvm/topi/adreno/conv2d_alter_op.py @@ -25,6 +25,7 @@ from tvm import relay from tvm import autotvm from ..utils import get_const_tuple +from .utils import infer_tile_size from ..nn import conv2d_alter_layout logger = logging.getLogger("topi") @@ -58,7 +59,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): kernel_layout = attrs["kernel_layout"] data_tensor, kernel_tensor = tinfos data_dtype = data_tensor.dtype - kernel_dtype = kernel_tensor.dtype out_dtype = out_type.dtype if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest): @@ -70,12 +70,228 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): ) workload = autotvm.task.get_workload(outs) if workload is None: + if impl.name.find("winograd") != -1: + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + assert (data_layout == "NCHW" and kernel_layout == "OIHW") or ( + data_layout == "NHWC" and kernel_layout == "HWIO" + ) + if data_layout == "NCHW": + N, CI, H, W = get_const_tuple(data_tensor.shape) + CO, _, KH, KW = get_const_tuple(kernel_tensor.shape) + weight = inputs[1] + else: + N, H, W, CI = get_const_tuple(data_tensor.shape) + KH, 
KW, _, CO = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "HWIO", "OIHW") + + # Pre-compute weight transformation in winograd + tile_size = infer_tile_size(data_tensor, data_layout) + + # alpha, alpha, CO, CI + weight = relay.nn.contrib_conv2d_winograd_weight_transform( + weight, tile_size=tile_size + ) + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) return None cfg = dispatch_ctx.query(target, workload) topi_tmpl = workload[0] + if "conv2d_nchw_winograd" in topi_tmpl: + suffix = "_acc32" if "acc32" in topi_tmpl else "" + wkl_name = "conv2d_nchw_winograd_without_weight_transform" + suffix + ".image2d" + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + tile_size = infer_tile_size(data_tensor, data_layout) + if len(data_tensor.shape) == 5: + assert data_layout == "NCHW4c" and kernel_layout == "OIHW4o" + N, CI, H, W, CB = get_const_tuple(data_tensor.shape) + CO, _, KH, KW, COB = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "OIHW4o", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.layout_transform(weight, "HWOI", "HWIO4o") + + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO * COB + + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI * CB, CO, COB), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + assert data_layout == "NCHW" and kernel_layout == "OIHW" + N, CI, H, W = get_const_tuple(data_tensor.shape) + CO, _, KH, KW = get_const_tuple(kernel_tensor.shape) + + # pre-compute weight transformation in winograd + weight = relay.nn.contrib_conv2d_winograd_weight_transform(inputs[1], tile_size=tile_size) + weight = relay.transpose(weight, axes=[2, 3, 0, 1]) # HWOI -> OIHW + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + + # Store the same config for the altered operator (workload) + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel_tensor.dtype + ) + in_channel_block = CI % 4 + if in_channel_block == 0: + in_channel_block = 4 + num_filter_block = CO % 4 + if num_filter_block == 0: + num_filter_block = 4 + + if in_channel_block != 4 or num_filter_block != 4: + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + new_attrs["data_layout"] = "NCHW%dc" % in_channel_block + # (oc, ic, h, w) -> (h, w, ic, oc // 4, oc % 4) + new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block + new_attrs["out_layout"] = "NCHW%dc" % num_filter_block + # Store altered operator's config + new_data = te.placeholder( + (N, CI // in_channel_block, H, W, in_channel_block), dtype=data_dtype + ) + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO // num_filter_block, num_filter_block), + dtype=kernel_tensor.dtype, + ) + 
new_workload = autotvm.task.args_to_workload( + [ + new_data, + new_weight, + strides, + padding, + dilation, + out_dtype, + ], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + if "conv2d_nhwc_winograd" in topi_tmpl: + suffix = "_acc32" if "acc32" in topi_tmpl else "" + wkl_name = "conv2d_nhwc_winograd_without_weight_transform" + suffix + ".image2d" + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + tile_size = infer_tile_size(data_tensor, data_layout) + if len(data_tensor.shape) == 5: + assert data_layout == "NHWC4c" and kernel_layout == "HWIO4o" + N, CI, H, W, CB = get_const_tuple(data_tensor.shape) + KH, KW, _, CO, COB = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "HWIO4o", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.layout_transform(weight, "HWOI", "HWIO4o") + + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO * COB + + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI * CB, CO, COB), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data_tensor.shape) + KH, KW, _, CO = get_const_tuple(kernel_tensor.shape) + + # pre-compute weight transformation in winograd + weight = relay.layout_transform(inputs[1], "HWIO", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) # HWOI -> HWIO + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + + # Store the same config for the altered operator (workload) + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel_tensor.dtype + ) + in_channel_block = CI % 4 + if in_channel_block == 0: + in_channel_block = 4 + num_filter_block = CO % 4 + if num_filter_block == 0: + num_filter_block = 4 + + if in_channel_block != 4 or num_filter_block != 4: + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + new_attrs["data_layout"] = "NHWC%dc" % in_channel_block + # (oc, ic, h, w) -> (h, w, ic, oc // 4, oc % 4) + new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block + new_attrs["out_layout"] = "NHWC%dc" % num_filter_block + # Store altered operator's config + new_data = te.placeholder( + (N, H, W, CI // in_channel_block, in_channel_block), dtype=data_dtype + ) + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO // num_filter_block, num_filter_block), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [ + new_data, + new_weight, + strides, + padding, + dilation, + out_dtype, + ], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return 
relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + if "conv2d_nchwc" in topi_tmpl: # covers both conv2d_nchwc and depthwise_conv2d_nchwc if data_layout == "NCHW" and kernel_layout == "OIHW": batch, in_channels, in_height, in_width = data_tensor.shape diff --git a/python/tvm/topi/adreno/conv2d_nchw_winograd.py b/python/tvm/topi/adreno/conv2d_nchw_winograd.py new file mode 100644 index 0000000000000..16f7cb8b19d95 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_nchw_winograd.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Winograd NCHW template for Adreno backend""" + +import logging +from tvm import autotvm +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl + + +logger = logging.getLogger("conv2d_nchw_winograd") + + +@autotvm.register_topi_compute("conv2d_nchw_winograd.image2d") +def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_acc32.image2d") +def conv2d_nchw_winograd_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd.image2d") +def schedule_conv2d_nchw_winograd(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd_acc32.image2d") +def schedule_conv2d_nchw_winograd_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.image2d") +def conv2d_nchw_winograd_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform_acc32.image2d") +def conv2d_nchw_winograd_without_weight_transform_acc32( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + 
+@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.image2d") +def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform_acc32.image2d") +def schedule_conv2d_nchw_winograd_without_weight_transform_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) + + +def conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. + + args: dict + Dictionary with additional arguments, e.g. accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + return conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NCHW" + ) diff --git a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py new file mode 100644 index 0000000000000..bfe385f210a49 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument +"""Winograd NHWC template for Adreno backend""" + +import logging +from tvm import autotvm +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl + + +logger = logging.getLogger("conv2d_nhwc_winograd") + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd.image2d") +def conv2d_nhwc_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_acc32.image2d") +def conv2d_nhwc_winograd_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd.image2d") +def schedule_conv2d_nhwc_winograd(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_acc32.image2d") +def schedule_conv2d_nhwc_winograd_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform.image2d") +def conv2d_nhwc_winograd_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d") +def conv2d_nhwc_winograd_without_weight_transform_acc32( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform.image2d") +def schedule_conv2d_nhwc_winograd_without_weight_transform(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d") +def schedule_conv2d_nhwc_winograd_without_weight_transform_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) + + +def conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. 
+ + args: dict + Dictionary with additional arguments, e.g. accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + return conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NHWC" + ) diff --git a/python/tvm/topi/adreno/conv2d_winograd_common.py b/python/tvm/topi/adreno/conv2d_winograd_common.py new file mode 100644 index 0000000000000..494b691a7f076 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_winograd_common.py @@ -0,0 +1,512 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Common Winograd implementation for Adreno backend""" + +import tvm +from tvm import te +from tvm import autotvm + +from tvm.topi import nn +from tvm.topi.utils import get_const_int, get_const_tuple, traverse_inline +from ..nn.winograd_util import winograd_transform_matrices +from .utils import ( + split_to_chunks, + pack_input, + pack_filter, + bind_data_copy, + get_texture_storage, + infer_tile_size, +) + + +def conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, layout +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. + + args: dict + Dictionary with additional arguments, e.g. 
accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + layout: str + NHWC or NCHW values are accepted + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + assert layout in ("NCHW", "NHWC") + tile_size = infer_tile_size(data, layout) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides + + convert_from4d = False + if len(data.shape) == 4: + if layout == "NCHW": + N, DCI, H, W = get_const_tuple(data.shape) + else: + N, H, W, DCI = get_const_tuple(data.shape) + if not pre_computed: + if layout == "NCHW": + out_channels, CI, KH, KW = get_const_tuple(kernel.shape) + else: + KH, KW, CI, out_channels = get_const_tuple(kernel.shape) + else: + alpha, _, CI, out_channels = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + + in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4) + out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) + if autotvm.GLOBAL_SCOPE.in_tuning is True: + if layout == "NCHW": + dshape = (N, in_channel_chunks, H, W, in_channel_block) + else: + dshape = (N, H, W, in_channel_chunks, in_channel_block) + if not pre_computed: # kernel tensor is raw tensor, do strict check + if layout == "NCHW": + kshape = (out_channel_chunks, CI, KH, KW, out_channel_block) + else: + kshape = (KH, KW, CI, out_channel_chunks, out_channel_block) + else: + kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block) + data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder") + kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder") + else: + convert_from4d = True + data = pack_input( + data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W + ) + kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" + if not pre_computed: # kernel tensor is raw tensor, do strict check + kernel = pack_filter( + kernel, + kernel_layout, + out_channel_chunks, + out_channel_block, + out_channel_tail, + CI, + in_channel_chunks, + in_channel_block, + in_channel_tail, + KH, + KW, + ) + else: + kernel = pack_filter( + kernel, + "HWIO", + out_channel_chunks, + out_channel_block, + out_channel_tail, + CI, + in_channel_chunks, + in_channel_block, + in_channel_tail, + alpha, + alpha, + ) + if layout == "NCHW": + N, DCI, H, W, CB = get_const_tuple(data.shape) + else: + N, H, W, DCI, CB = get_const_tuple(data.shape) + if not pre_computed: # kernel tensor is raw tensor, do strict check + if layout == "NCHW": + CO, CI, KH, KW, COB = get_const_tuple(kernel.shape) + else: + KH, KW, CI, CO, COB = get_const_tuple(kernel.shape) + alpha = KW + tile_size - 1 + assert HSTR == 1 and WSTR == 1 and KH == KW + else: + alpha, _, CI, CO, COB = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 + + if isinstance(N, tvm.tir.Any): + N = tvm.te.size_var("n") + + if not isinstance(H, int) or not isinstance(W, int): + raise RuntimeError( + "adreno winograd conv2d doesn't support dynamic input\ + height or width." 
+ ) + + pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) + if layout == "NCHW": + data_pad = nn.pad(data, (0, 0, pt, pl, 0), (0, 0, pb, pr, 0), name="data_pad") + else: + data_pad = nn.pad(data, (0, pt, pl, 0, 0), (0, pb, pr, 0, 0), name="data_pad") + + r = KW + m = tile_size + A, B, G = winograd_transform_matrices(m, r, out_dtype) + + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + + P = N * nH * nW if isinstance(N, int) else nH * nW + + # transform kernel + if not pre_computed: + r_kh = te.reduce_axis((0, KH), name="r_kh") + r_kw = te.reduce_axis((0, KW), name="r_kw") + if layout == "NCHW": + kernel_pack = te.compute( + (alpha, alpha, CI, CO, COB), + lambda eps, nu, ci, co, cob: te.sum( + kernel[co][ci][r_kh][r_kw][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] + ), + name="kernel_pack", + ) + else: + kernel_pack = te.compute( + (alpha, alpha, CI, CO, COB), + lambda eps, nu, ci, co, cob: te.sum( + kernel[r_kh][r_kw][ci][co][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] + ), + name="kernel_pack", + ) + else: + kernel_pack = kernel + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + if layout == "NCHW": + N, CI, H, W, CB = get_const_tuple(data.shape) + else: + N, H, W, CI, CB = get_const_tuple(data.shape) + + # pack input tile + if layout == "NCHW": + input_tile = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][c][ + idxmod(idxdiv(p, nW), nH) * m + eps + ][idxmod(p, nW) * m + nu][cb], + name="d", + ) + else: + input_tile = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][ + idxmod(idxdiv(p, nW), nH) * m + eps + ][idxmod(p, nW) * m + nu][c][cb], + name="d", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_a") + data_pack = te.compute( + (P, CI, alpha, alpha, CB), + lambda p, ci, eps, nu, cb: te.sum( + input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + ) + + # repack transformed data + data_pack_trans = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb], + name="data_pack_trans", + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + cb = te.reduce_axis((0, CB), name="cb") + bgemm = te.compute( + (alpha, alpha, CO, P, COB), + lambda eps, nu, co, p, cob: te.sum( + ( + kernel_pack[eps][nu][ci * CB + cb][co][cob] * data_pack_trans[eps][nu][ci][p][cb] + ).astype(args["accumulator"]), + axis=[ci, cb], + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_a") + inverse = te.compute( + (CO, P, m, m, COB), + lambda co, p, vh, vw, cob: te.sum( + bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]), + axis=[r_a, r_b], + ), + name="inverse", + ) + + # output + if layout == "NCHW": + if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: + output = te.compute( + (N, out_channels, H, W), + lambda n, c, h, w: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ + idxmod(h, m) + ][idxmod(w, m)][c % CB].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + output = te.compute( + (N, CO, H, W, COB), + lambda n, co, h, w, cob: inverse[co][ + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m) + ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype), + name="output", + 
tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: + output = te.compute( + (N, H, W, out_channels), + lambda n, h, w, c: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ + idxmod(h, m) + ][idxmod(w, m)][c % CB].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + output = te.compute( + (N, H, W, CO, COB), + lambda n, h, w, co, cob: inverse[co][ + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m) + ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + + if isinstance(N, int): + cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW) + + return output + + +def schedule_conv2d_winograd_impl(cfg, outs, tag, pre_computed=False): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == tag: + schedule_conv2d_winograd(cfg, s, op.output(0), pre_computed=pre_computed) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def schedule_conv2d_winograd(cfg, s, output, pre_computed): + """Schedule winograd template""" + inverse = s[output].op.input_tensors[0] + bgemm, A = s[inverse].op.input_tensors + kernel_pack, data_pack_trans = s[bgemm].op.input_tensors + data_pack = s[data_pack_trans].op.input_tensors[0] + input_tile, B = s[data_pack].op.input_tensors + pad_data = s[input_tile].op.input_tensors[0] + + # data transform + s[B].compute_inline() + s[A].compute_inline() + + # probably will improve real topology execution + if autotvm.GLOBAL_SCOPE.in_tuning: + # Padding to texture + AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile]) + bind_data_copy(s[AA]) + + s[input_tile].compute_inline() + + OL = s.cache_write(data_pack, "local") + c, p, eps, nu, cb = s[data_pack].op.axis + fused = s[data_pack].fuse(c, p, eps, nu) + bx, tx = s[data_pack].split(fused, 128) + s[data_pack].vectorize(cb) + s[data_pack].bind(bx, te.thread_axis("blockIdx.x")) + s[data_pack].bind(tx, te.thread_axis("threadIdx.x")) + + _, _, eps, nu, cb = s[OL].op.axis + r_a, r_b = s[OL].op.reduce_axis + s[OL].unroll(eps) + s[OL].unroll(nu) + s[OL].unroll(r_a) + s[OL].unroll(r_b) + s[OL].vectorize(cb) + s[OL].compute_at(s[data_pack], tx) + s[data_pack].set_scope(get_texture_storage(data_pack.shape)) + + s[data_pack_trans].compute_inline() + + # transform kernel + if not pre_computed: + kernel, G = s[kernel_pack].op.input_tensors + eps, nu, ci, co, cob = s[kernel_pack].op.axis + if autotvm.GLOBAL_SCOPE.in_tuning: + # skip this part during tuning to make recrods accurate + # this part will be pre-computed during pre-compute optimization pass + s[G].pragma(s[G].op.axis[0], "debug_skip_region") + s[kernel_pack].pragma(eps, "debug_skip_region") + else: + s[G].compute_inline() + r_a, r_b = s[kernel_pack].op.reduce_axis + for axis in [eps, nu, r_a, r_b]: + s[kernel_pack].unroll(axis) + + fused = s[kernel_pack].fuse(ci, co) + bb, tt = s[kernel_pack].split(fused, 128) + s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob) + s[kernel_pack].vectorize(cob) + s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag: + # manage scheduling of datacopy + pack_data = pad_data.op.input_tensors[0] + bind_data_copy(s[pack_data]) + bind_data_copy(s[kernel]) + 
elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + s[pad_data].compute_inline() + + ##### space definition begin ##### + cfg.define_knob("auto_unroll_max_step", [0, 4, 16]) + b1, b2, y, x, cb = s[bgemm].op.axis + rcc = s[bgemm].op.reduce_axis[0] + alpha = get_const_int(b1.dom.extent) + + cfg.define_split( + "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 8 + ) + cfg.define_split( + "tile_x", + x, + num_outputs=3, + filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= 4 and entry.size[1] <= 8, + ) + cfg.define_split("tile_rc", rcc, num_outputs=2) + # TODO: Uncomment the following lines when multi_filter will be introduced + # cfg.multi_filter( + # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024) + # ) + ##### space definition end ##### + + # batch gemm + OL = s.cache_write(bgemm, "local") + if ( + autotvm.GLOBAL_SCOPE.in_tuning + or isinstance(kernel.op, tvm.te.ComputeOp) + and "filter_pack" in kernel.op.tag + ): + BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL]) + bind_data_copy(s[BB]) + + by = s[bgemm].fuse(b1, b2, y) + + # tile and bind spatial axes + bgemm_scope, by = s[bgemm].split(by, nparts=1) + by, vy, ty = cfg["tile_y"].apply(s, bgemm, by) + bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x) + s[bgemm].bind(by, te.thread_axis("blockIdx.y")) + s[bgemm].bind(bx, te.thread_axis("blockIdx.x")) + s[bgemm].bind(vy, te.thread_axis("vthread")) + s[bgemm].bind(vx, te.thread_axis("vthread")) + s[bgemm].bind(ty, te.thread_axis("threadIdx.y")) + s[bgemm].bind(tx, te.thread_axis("threadIdx.x")) + s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb) + s[bgemm].vectorize(cb) + s[bgemm].set_scope(get_texture_storage(bgemm.shape)) + + # tile reduction axes + s[OL].compute_at(s[bgemm], tx) + b1, b2, y, x, cb = s[OL].op.axis + (rcc, rcb) = s[OL].op.reduce_axis + b = s[OL].fuse(b1, b2) + s[OL].reorder(b, y, x, rcc, rcb, cb) + # s[OL].unroll(rcb) + s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[OL].pragma(rcb, "unroll_explicit", True) + s[OL].vectorize(cb) + + # schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope("local") + output = s.outputs[0] + + m = alpha - 3 + 1 + if len(s[output].op.axis) == 4: + n, co, h, w = s[output].op.axis + else: + n, co, h, w, _ = s[output].op.axis + ho, wo, hi, wi = s[output].tile(h, w, m, m) + inverse_scope, n = s[output].split(n, nparts=1) + + fused = s[output].fuse(n, co, ho, wo) + bb, tt = s[output].split(fused, 128) + + s[output].bind(bb, te.thread_axis("blockIdx.x")) + s[output].bind(tt, te.thread_axis("threadIdx.x")) + + if OL is not None: + s[OL].compute_at(s[output], tt) + + co, p, vh, vw, cb = s[inverse].op.axis + r_a, r_b = s[inverse].op.reduce_axis + for axis in [vh, vw, r_a, r_b]: + s[inverse].unroll(axis) + s[inverse].vectorize(cb) + s[inverse].compute_at(s[output], tt) + + return s diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py index 727741c11fd3f..78a992e56a0f9 100644 --- a/python/tvm/topi/adreno/utils.py +++ b/python/tvm/topi/adreno/utils.py @@ -547,3 +547,31 @@ def get_texture_storage(shape): return "global.texture-nhwc" else: return "global.texture-weight" + + +def infer_tile_size(data, layout): + """Compute the tile size for Winograd algorithm + + Parameters + ---------- + data: tvm.te.Tensor + Data tensor + + layout: string + Layout of data tebsir + 
NCHW, NCHW4c, NHWC or NHWC4c are acceptable + + Returns + ------- + tile_size : int + Calculated tile size + """ + assert layout in ("NCHW", "NCHW4c", "NHWC", "NHWC4c"), "Incompatible layout" + if layout in ("NCHW", "NCHW4c"): + H = get_const_tuple(data.shape)[2] + else: + H = get_const_tuple(data.shape)[1] + + if H % 8 == 0: + return 4 + return 2 diff --git a/src/runtime/opencl/texture_pool.cc b/src/runtime/opencl/texture_pool.cc index e7f6655c41142..0b9477f2d4ea3 100644 --- a/src/runtime/opencl/texture_pool.cc +++ b/src/runtime/opencl/texture_pool.cc @@ -29,113 +29,112 @@ namespace tvm { namespace runtime { -class TexturePool::Pool { - public: - Pool() = default; - void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint) { - Entry e; - e.data = nullptr; - if (free_list_.size() != 0) { - Entry new_mem; - int64_t min_added_size_x = std::numeric_limits::max(); - int64_t min_added_size_y = std::numeric_limits::max(); - int64_t min_wasted_size_x = std::numeric_limits::max(); - int64_t min_wasted_size_y = std::numeric_limits::max(); - std::vector::iterator best_mem; - for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { - if (it->type.code != type_hint.code) { - continue; - } - new_mem.x = std::max(it->x, width); - new_mem.y = std::max(it->y, height); - int64_t added_size_x = new_mem.x - it->x; - int64_t added_size_y = new_mem.y - it->y; - int64_t wasted_size_x = new_mem.x - width; - int64_t wasted_size_y = new_mem.y - height; - // Minimize added size first and wasted size thereafter - if ((min_added_size_x > 0 && added_size_x < min_added_size_x) || - (min_added_size_y > 0 && added_size_y < min_added_size_y) || - (min_added_size_x == added_size_x && wasted_size_x < min_wasted_size_x) || - (min_added_size_y == added_size_y && wasted_size_y < min_wasted_size_y)) { - min_added_size_x = added_size_x; - min_added_size_y = added_size_y; - min_wasted_size_x = wasted_size_x; - min_wasted_size_y = wasted_size_y; - best_mem = it; - } +void* Pool2D::Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, + DLDataType type_hint) { + Entry e; + Entry new_mem; + // Processed several experiments and found that when we are trying to fit + // small texture to too big texture then it may lead to the performance + // degradation. + // Coefficient at 5 looks like robust variant for reusing textures. 
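+  // For example, with max_ratio = 5 a free 2048x64 texture is not considered for a
+  // 64x64 request (2048 / 64 = 32 > 5), and a 16x16 texture is not grown to serve a
+  // 1024x1024 request; in those cases a fresh texture is allocated instead.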
+ const int64_t max_ratio = 5; + e.data = nullptr; + std::vector::iterator best_mem; + if (free_list_.size() != 0) { + int64_t min_added_size_x = std::numeric_limits::max(); + int64_t min_added_size_y = std::numeric_limits::max(); + int64_t min_wasted_size_x = std::numeric_limits::max(); + int64_t min_wasted_size_y = std::numeric_limits::max(); + for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { + if (it->type.code != type_hint.code) { + continue; } - - if (min_added_size_x == 0 && min_added_size_y == 0) { - // use existing block - e = *best_mem; - free_list_.erase(best_mem); - } else if (static_cast(min_added_size_x) <= width || - static_cast(min_added_size_y) <= height) { - // if added size is less or equal to - // what is needed by alloc, then grow entry - device->FreeDataSpace(dev, best_mem->data); - free_list_.erase(best_mem); - new_mem.type = type_hint; - std::vector shape{int64_t(new_mem.y), int64_t(new_mem.x), 4}; - new_mem.data = device->AllocDataSpace(dev, shape.size(), shape.data(), new_mem.type, - Optional("global.texture")); - e = new_mem; + // avoid reusing too small and too big textures + if (width / it->x > max_ratio || it->x / width > max_ratio || height / it->y > max_ratio || + it->y / height > max_ratio) { + continue; + } + int64_t new_width = std::max(it->x, width); + int64_t new_height = std::max(it->y, height); + int64_t added_size_x = new_width - it->x; + int64_t added_size_y = new_height - it->y; + int64_t wasted_size_x = new_width - width; + int64_t wasted_size_y = new_height - height; + // Minimize added size first and wasted size thereafter + if ((min_added_size_x > 0 && added_size_x < min_added_size_x) || + (min_added_size_y > 0 && added_size_y < min_added_size_y) || + (min_added_size_x == added_size_x && wasted_size_x < min_wasted_size_x) || + (min_added_size_y == added_size_y && wasted_size_y < min_wasted_size_y)) { + min_added_size_x = added_size_x; + min_added_size_y = added_size_y; + min_wasted_size_x = wasted_size_x; + min_wasted_size_y = wasted_size_y; + best_mem = it; + new_mem.x = new_width; + new_mem.y = new_height; } } - if (e.data == nullptr) { - // create new block - std::vector shape{int64_t(height), int64_t(width), 4}; - e.data = device->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - Optional("global.texture")); - e.x = width; - e.y = height; - e.type = type_hint; + if (min_added_size_x == 0 && min_added_size_y == 0) { + // use existing block + e = *best_mem; + free_list_.erase(best_mem); + } else if (static_cast(min_added_size_x) <= width || + static_cast(min_added_size_y) <= height) { + // if added size is less or equal to + // what is needed by alloc, then grow entry + device->FreeDataSpace(dev, best_mem->data); + free_list_.erase(best_mem); + new_mem.type = type_hint; + std::vector shape{int64_t(new_mem.y), int64_t(new_mem.x), 4}; + new_mem.data = device->AllocDataSpace(dev, shape.size(), shape.data(), new_mem.type, + Optional("global.texture")); + e = new_mem; } - - allocated_.push_back(e); - return e.data; } - void Free(void* data) { - Entry e; - if (allocated_.back().data == data) { - // quick path, last allocated. 
- e = allocated_.back(); - allocated_.pop_back(); - } else { - int index = static_cast(allocated_.size()) - 2; - for (; index >= 0 && allocated_[index].data != data; --index) { - } - ICHECK_GE(index, 0) << "Attempt to free texture that has not been allocated"; - e = allocated_[index]; - allocated_.erase(allocated_.begin() + index); - } - free_list_.push_back(e); + if (e.data == nullptr) { + // create new block + std::vector shape{int64_t(height), int64_t(width), 4}; + e.data = device->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, + Optional("global.texture")); + e.x = width; + e.y = height; + e.type = type_hint; } - // Release all resources immediately - void Release(Device dev, DeviceAPI* device) { - for (auto& e : allocated_) { - device->FreeDataSpace(dev, e.data); - } - for (auto& e : free_list_) { - device->FreeDataSpace(dev, e.data); + allocated_.push_back(e); + return e.data; +} + +void Pool2D::Free(void* data) { + Entry e; + if (allocated_.back().data == data) { + // quick path, last allocated. + e = allocated_.back(); + allocated_.pop_back(); + } else { + int index = static_cast(allocated_.size()) - 2; + for (; index >= 0 && allocated_[index].data != data; --index) { } - allocated_.clear(); - free_list_.clear(); + ICHECK_GE(index, 0) << "Attempt to free texture that has not been allocated"; + e = allocated_[index]; + allocated_.erase(allocated_.begin() + index); } + free_list_.push_back(e); +} - private: - struct Entry { - void* data; - size_t x; - size_t y; - DLDataType type; - }; - std::vector free_list_; - std::vector allocated_; -}; +// Release all resources immediately +void Pool2D::Release(Device dev, DeviceAPI* device) { + for (auto& e : allocated_) { + device->FreeDataSpace(dev, e.data); + } + for (auto& e : free_list_) { + device->FreeDataSpace(dev, e.data); + } + allocated_.clear(); + free_list_.clear(); +} TexturePool::TexturePool(DLDeviceType device_type, DeviceAPI* device) : device_type_(device_type), device_(device) {} @@ -157,7 +156,7 @@ void* TexturePool::AllocTexture(Device dev, size_t width, size_t height, DLDataT array_.resize(dev.device_id + 1, nullptr); } if (array_[dev.device_id] == nullptr) { - array_[dev.device_id] = new Pool(); + array_[dev.device_id] = new Pool2D(); } return array_[dev.device_id]->Alloc(dev, device_, width, height, type_hint); } diff --git a/src/runtime/texture.h b/src/runtime/texture.h index 5f43c8cee8f3f..dc38101f0cd4f 100644 --- a/src/runtime/texture.h +++ b/src/runtime/texture.h @@ -94,6 +94,25 @@ inline bool IsTextureStorage(std::string scope) { return scope.find("texture") != std::string::npos; } +class TVM_DLL Pool2D { + public: + Pool2D() = default; + void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint); + void Free(void* data); + // Release all resources immediately + void Release(Device dev, DeviceAPI* device); + + protected: + struct Entry { + void* data; + size_t x; + size_t y; + DLDataType type; + }; + std::vector free_list_; + std::vector allocated_; +}; + /*! * \brief A two dimensional storage pool that recycles temporal workspace * allocations for dynamically allocated texture. See AllocTexture docstring @@ -136,9 +155,8 @@ class TVM_DLL TexturePool { void FreeTexture(Device dev, void* ptr); private: - class Pool; /*! \brief pool of device local array */ - std::vector array_; + std::vector array_; /*! \brief device type this pool support */ DLDeviceType device_type_; /*! 
\brief The device API */ diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index e6f322885e3a2..4a969dcee8bb9 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -43,6 +43,10 @@ #define TVM_INFO_USE_OPENCL "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_OPENCL_GTEST +#define TVM_INFO_USE_OPENCL_GTEST "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_VULKAN #define TVM_INFO_USE_VULKAN "NOT-FOUND" #endif @@ -286,6 +290,7 @@ TVM_DLL Map GetLibInfo() { {"USE_MSVC_MT", TVM_INFO_USE_MSVC_MT}, {"USE_NNPACK", TVM_INFO_USE_NNPACK}, {"USE_OPENCL", TVM_INFO_USE_OPENCL}, + {"USE_OPENCL_GTEST", TVM_INFO_USE_OPENCL_GTEST}, {"USE_OPENMP", TVM_INFO_USE_OPENMP}, {"USE_PAPI", TVM_INFO_USE_PAPI}, {"USE_PROFILER", TVM_INFO_USE_PROFILER}, diff --git a/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc new file mode 100644 index 0000000000000..2d3f43ddce6de --- /dev/null +++ b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "../src/runtime/opencl/opencl_common.h" +#include "../src/runtime/texture.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::cl; + +// PoolWrapper is necessary because in class Pool2D we don't have an access to +// its protected members. 
In this class we add new methods which allow us to +// get and check internal state of class Pool +class PoolWrapper : public Pool2D { + public: + inline size_t FreeListSize() const { return free_list_.size(); } + inline size_t AllocatedListSize() const { return allocated_.size(); } + inline std::pair FreeListItemSize(size_t idx) const { + return std::make_pair(free_list_[idx].x, free_list_[idx].y); + } + inline std::pair AllocatedListItemSize(size_t idx) const { + return std::make_pair(allocated_[idx].x, allocated_[idx].y); + } +}; + +TEST(OpenCLTexturePool, textures_reallocation_optimal_size) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 1024, 768, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); + + pool.Alloc(t->device, workspace, 64, 12455, type); + EXPECT_EQ(pool.AllocatedListSize(), 2); + EXPECT_EQ(pool.FreeListSize(), 0); + item = pool.AllocatedListItemSize(1); + EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); + + pool.Alloc(t->device, workspace, 768, 1024, type); + EXPECT_EQ(pool.AllocatedListSize(), 2); + EXPECT_EQ(pool.FreeListSize(), 0); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + item = pool.AllocatedListItemSize(1); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 1024); +} + +TEST(OpenCLTexturePool, avoid_reusing_too_big_textures) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 12455, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + + pool.Alloc(t->device, workspace, 1024, 768, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); +} + +TEST(OpenCLTexturePool, avoid_reusing_too_small_textures) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 1024, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = 
pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + + pool.Alloc(t->device, workspace, 12544, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 12544); + EXPECT_EQ(item.second, 64); +} diff --git a/tests/cpp-runtime/opencl/run_gtests.cc b/tests/cpp-runtime/opencl/run_gtests.cc new file mode 100644 index 0000000000000..b16ae3efc74d9 --- /dev/null +++ b/tests/cpp-runtime/opencl/run_gtests.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include + +#include "../src/support/utils.h" + +namespace tvm { +namespace runtime { +namespace cl { + +TVM_REGISTER_GLOBAL("opencl.run_gtests").set_body([](TVMArgs args, TVMRetValue* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0], ' '); + std::vector argv; + + // add executable name + argv.push_back(const_cast("opencl_run_gtests")); + + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } + + // end of parsed arguments + argv.push_back(nullptr); + + // set argument count + int argc = argv.size() - 1; + + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + *rv = RUN_ALL_TESTS(); +}); + +} // namespace cl +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_opencl/conftest.py b/tests/python/contrib/test_opencl/conftest.py new file mode 100644 index 0000000000000..0a8b9e1c631f0 --- /dev/null +++ b/tests/python/contrib/test_opencl/conftest.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" OpenCL testing fixtures used to deduce testing argument + values from testing parameters """ + + +import pytest + +import tvm +import tvm.testing + +pytest_plugins = [ + "tvm.contrib.hexagon.pytest_plugin", +] diff --git a/tests/python/contrib/test_opencl/test_run_gtests.py b/tests/python/contrib/test_opencl/test_run_gtests.py new file mode 100644 index 0000000000000..4afcf7ee8d660 --- /dev/null +++ b/tests/python/contrib/test_opencl/test_run_gtests.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import pytest +import numpy as np + +import tvm +from tvm import rpc + + +# use pytest -sv to observe gtest output +# use --gtest_args to pass arguments to gtest +# for example to run all "foo" tests twice and observe gtest output run +# pytest -sv --gtests_args="--gtest_filter=*foo* --gtest_repeat=2" +@tvm.testing.requires_opencl +def test_run_gtests(gtest_args): + if ( + "TVM_TRACKER_HOST" in os.environ + and "TVM_TRACKER_PORT" in os.environ + and "TVM_TRACKER_KEY" in os.environ + ): + rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] + rpc_tracker_port = os.environ["TVM_TRACKER_PORT"] + rpc_tracker_port = int(rpc_tracker_port) + rpc_key = os.environ["TVM_TRACKER_KEY"] + tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) + rpc_connection = tracker.request(rpc_key, priority=0, session_timeout=600) + else: + rpc_connection = rpc.LocalSession() + + try: + func = rpc_connection.get_function("opencl.run_gtests") + except: + print( + "This test requires TVM Runtime to be built with a OpenCL gtest version using OpenCL API cmake flag -DUSE_OPENCL_GTEST=/path/to/opencl/googletest/gtest" + ) + raise + + gtest_error_code = func(gtest_args) + np.testing.assert_equal(gtest_error_code, 0) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index d36da51c8f713..89f68dacbd3ff 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import re import tvm import numpy as np from tvm import relay @@ -392,3 +393,45 @@ def test_conv2d_yolov3_v2_nchw_3c(): } build_run_compare(mod, params, {"data": input_shape}, dtype, target) + + +@tvm.testing.requires_opencl +def test_conv2d_vgg16_winograd_4d(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 512, 28, 28) + filter_shape = (512, 512, 3, 3) + bias_shape = (1, 512, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[1, 1, 1, 1], + channels=512, + kernel_size=[3, 3], + out_dtype=dtype, + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + matches = re.findall("winograd", graph) + assert len(matches) > 0 diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py index a02b7cabbef62..96227ca551cf9 100644 --- a/tests/python/relay/test_conv2d_nhwc_texture.py +++ b/tests/python/relay/test_conv2d_nhwc_texture.py @@ -16,6 +16,7 @@ # under the License. import os +import re import tvm import numpy as np from tvm import relay @@ -554,3 +555,45 @@ def test_conv2d_yolov3_v2_nhwc_3c(): } build_run_compare(mod, params, {"data": input_shape}, dtype, target) + + +@tvm.testing.requires_opencl +def test_conv2d_vgg16_winograd_4d(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 28, 28, 512) + filter_shape = (3, 3, 512, 512) + bias_shape = (1, 1, 1, 512) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NHWC", + kernel_layout="HWIO", + padding=[1, 1, 1, 1], + channels=512, + kernel_size=[3, 3], + out_dtype=dtype, + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + matches = re.findall("winograd", graph) + assert len(matches) > 0 diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py index 11abce3bfaa0a..3bb4a6ada4ecc 100644 --- a/tests/python/relay/utils/adreno_utils.py +++ b/tests/python/relay/utils/adreno_utils.py @@ -105,6 +105,7 @@ def build_run_compare( # print(index, output[index], x) np.testing.assert_allclose(output, ref_output, rtol=1e-1, atol=1e-1) + return graph def gpu_preprocess(tvm_mod): From 236eea0f49b4ca9a30e99d54f2ceb7ee3ef836f7 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: 
Thu, 9 Jun 2022 10:19:31 +0100 Subject: [PATCH 076/181] [CMSIS-NN] Removed redudant arguments to CMSIS-NN wrapper function (#11431) Removed input_scale and filter_scale from CMSIS-NN wrapper function. These are not needed by CMSIS-NN API which gets called from the generated C wrapper function for Conv2D. --- .../backend/contrib/cmsisnn/relay_to_tir.cc | 29 +++++- .../contrib/test_cmsisnn/test_conv2d.py | 96 ++++++++++++++++++- 2 files changed, 121 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index dc5537ee905d8..524735caa9d6a 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -141,18 +141,24 @@ class RelayToTIRVisitor : public MixedModeMutator { // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5, // %output_scale_scalar, %output_zero_point_scalar) // clip(%3, a_min=%min_scalar, a_max=%max_scalar) + // Position of scales in the global function for Conv2D + const int filter_scale_pos = 3; + const int input_scale_pos = bias_add_call ? 5 : 4; BufferCreator buffer_creator; tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8)); tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8)); tir::Var multiplier = buffer_creator.CreateBufferVar("multiplier", DataType::Handle(32)); - tir::Var filter_scale = buffer_creator.CreateBufferVar("filter_scale", DataType::Handle(32)); if (bias_add_call) { buffer_creator.CreateBufferVar("bias", DataType::Handle(32)); } - tir::Var input_scale = buffer_creator.CreateBufferVar("input_scale", DataType::Handle(32)); tir::Var shift = buffer_creator.CreateBufferVar("shift", DataType::Handle(32)); tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); + // Relay function contains input_scale and filter_scale as function parameters at the following + // locations in the global partitioned function for Conv2D + skip_call_args_.insert(filter_scale_pos); + skip_call_args_.insert(input_scale_pos); + // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 @@ -742,11 +748,25 @@ class RelayToTIRVisitor : public MixedModeMutator { GetRef(func)); } + // Drop out the redundant arguments, and the arg_types from the global function call Array args; + Array arg_types; + auto* func_type = new_global_var->checked_type_.as(); + int arg_id = -1; for (const auto& arg : call->args) { + ++arg_id; + if (std::find(skip_call_args_.begin(), skip_call_args_.end(), arg_id) != + skip_call_args_.end()) { + continue; + } args.push_back(VisitExpr(arg)); + arg_types.push_back(func_type->arg_types[arg_id]); } - + if (arg_types.size() != func_type->arg_types.size()) { + new_global_var->checked_type_ = + FuncType(arg_types, func_type->ret_type, {}, func_type->type_constraints); + } + skip_call_args_.clear(); return Call(new_global_var, args, call->attrs, call->type_args, call->span); } } @@ -757,7 +777,10 @@ class RelayToTIRVisitor : public MixedModeMutator { static constexpr int32_t kScaledDiffIntegerBits = 5; static constexpr int32_t kInputBits = 5; static constexpr double kBeta = 1.0; + /*! \brief Unique id for context buffer needed by CMSIS-NN layers. */ int32_t context_buffer_id_; + /*! 
\brief Skip arguments in the call to global partitioned function. */ + std::unordered_set skip_call_args_; IRModule ir_module_; Target target_; }; diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 439a3ec39c9a7..90261e540a7d6 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -23,7 +23,7 @@ from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run +from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_models, compile_and_run from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from utils import ( @@ -119,6 +119,100 @@ def make_model( return last_op, params +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2)], +) +def test_conv2d_number_primfunc_args( + padding, + enable_bias, + input_zero_point, + input_scale, + kernel_scale, + out_channels, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + ifm_shape = (1, 64, 100, 4) + kernel_size = (3, 3) + strides = (1, 1) + dilation = (1, 1) + dtype = "int8" + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + relu_type = "RELU" + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod) + + # compile the model + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + + compiled_models = compile_models( + AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, params=params), + interface_api, + use_unpacked_api, + ) + + # validate number of TIR primfunc args + expected_num_params = 6 if enable_bias else 5 + cmsisnn_tir_mod = None + for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items(): + if "cmsis-nn" == target.kind.name: + cmsisnn_tir_mod = mod + + cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] + assert ( + len(cmsisnn_func.params) == expected_num_params + ), "Generated unexpected number of function arguments" + + @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("relu_type", ["RELU"]) From d8678a6a9aa7962b658efb603e27d83ea7737a02 Mon Sep 17 00:00:00 2001 From: FranckQC <89943638+FranckQC@users.noreply.github.com> Date: Thu, 9 Jun 2022 11:32:15 -0500 Subject: [PATCH 077/181] [TIR] CSE pass : Restrict the equivalence to 
be decided by a normal form - avoids comparison of terms (#11574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The CSE pass had been designed to potentially allow comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of being equivalent was customizable, and no assumption was made about it**. That meant that the implementation of the equivalence test function `EquivalentTerms()` - which at that point just called the syntactical equality test `EqualTerms()` - could later be replaced by a cleverer equality test.

However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we go from a hashtable of syntactical entities to what I called a vector of "semantical entities" (which are just canonical forms/representatives of classes of equivalence of terms), **the only way was to compare each pair**. That resulted in a quadratic behavior of this function, and there was no way around it, since merging equivalent entities into their class of equivalence requires comparing them.

**This PR essentially does the following:**

- When computing the classes of equivalence of terms (i.e. transforming a ComputationTable, which is a hashtable, into a vector of classes of equivalence): **instead of comparing each pair of terms, it relies on a normalization procedure to obtain a normal form for each of them**. That turns a small part of the algorithm that was quadratic into an n.log(n) one. However, it is difficult to see improvements in practice, in particular for average-sized programs, as that part was a "small" quadratic next to a "big" n.log(n) (finding things in a hashtable, copying them into a vector, etc.). It probably goes from a complexity of ~O(((n²-n)/2) + n.log(n)) to ~O(3n + n.log(n)), so gains would only be expected for very large programs.

- Completely exposes the possibility to turn the semantical comparison of terms ON or OFF. It is OFF by default (compilation is noticeably longer with it ON, unsurprisingly), which means that by default the equivalence coincides with the (syntactical) equality of terms. As the pass was written with the possibility of doing these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to plug that in completely, up to the Python user, who can now turn it ON when desired. Again, it is OFF by default, so there is no real change to the default behavior. To turn it ON, simply do `with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):` before calling `build()`.

- When this boolean is ON, the pass uses a simple implementation of the normalization function that relies on `arith::Analyzer::Simplify`, as noted in https://github.com/apache/tvm/pull/10544 (a short illustrative sketch is included further below). Note that this is not a true normalization procedure, as it is incomplete (i.e. it is not guaranteed to converge to the normal form), but it is correct, and it handles the most useful properties: associativity of +, distributivity of * over +, etc.

- Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in https://github.com/apache/tvm/pull/10544 but which did not make it through.

- Also add the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and put it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from https://github.com/apache/tvm/pull/10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them and relying on a human to compare them). --- include/tvm/tir/transform.h | 3 +- python/tvm/tir/transform/transform.py | 4 +- src/driver/driver_api.cc | 6 +- src/tir/transforms/common_subexpr_elim.cc | 96 +++++-- src/tir/transforms/common_subexpr_elim.h | 8 +- .../transforms/common_subexpr_elim_tools.cc | 145 +++++++--- .../transforms/common_subexpr_elim_tools.h | 10 +- .../test_tir_transform_common_subexpr_elim.py | 260 ++++++++++++++---- 8 files changed, 409 insertions(+), 123 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 24c3cfa78f721..4612d5ad3feac 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -470,9 +470,10 @@ TVM_DLL Pass LowerVtcmAlloc(); * \brief Implements a Common Subexpression Elimination (CSE) for TIR * which introduces let-in bindings for duplicated sub-expressions. * \param enable_cse_tir Whether common subexpression elimination is enabled. + * \param identify_equiv_terms Whether equivalent terms should be identified. * \return The pass. */ -TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true); +TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); /*! * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 802fdc576c41f..1bed29c560fc9 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -324,7 +324,7 @@ def BF16TypeLowering(): return _ffi_api.BF16TypeLowering() # type: ignore -def CommonSubexprElimTIR(enable_cse_tir: bool = True): +def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): """Replace redundant computations by new variables. 
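A minimal editorial sketch of the normalization idea described in the commit message above: two syntactically different but equivalent terms are reduced to the same normal form by the arithmetic analyzer, so they can be grouped without pairwise comparisons. The variable names are illustrative only.

    import tvm
    from tvm import te, arith

    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ana = arith.Analyzer()
    lhs = ana.simplify((x + y) + z)  # associativity: both spellings normalize
    rhs = ana.simplify(x + (y + z))  # to the same canonical sum
    assert tvm.ir.structural_equal(lhs, rhs)

    # The equivalence-based commoning itself stays opt-in:
    # with tvm.transform.PassContext(config={"tir.enable_equiv_terms_in_cse_tir": True}):
    #     tvm.build(...)  # CSE then identifies terms up to this normal form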
Returns @@ -332,7 +332,7 @@ def CommonSubexprElimTIR(enable_cse_tir: bool = True): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.CommonSubexprElimTIR(enable_cse_tir) # type: ignore + return _ffi_api.CommonSubexprElimTIR(enable_cse_tir, identify_equiv_terms) # type: ignore def RewriteUnsafeSelect(): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7df1a844acc2b..7706f229c9ed3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); @@ -198,6 +199,8 @@ Array CreatePassList(bool disable_loop_partition) { bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); + bool enable_equiv_terms_in_cse_tir = + pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -289,7 +292,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InstrumentBoundCheckers()); } - pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir)); + pass_list.push_back( + tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir)); return pass_list; } diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index d43b30d17be00..290f920e3fc07 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -60,7 +60,7 @@ namespace tir { to collect them for the CSE pass, but we also won't even want to collect computations that contain them. The reason is that reusing such computations would change the semantics of the program, - and therefore before doing any introduction of variable or any reuse of already introduced + and therefore before doing any introduction of var or any reuse of already introduced variables, we will make sure that the computation being considered is not forbidden, and that it does not even contain a forbidden computation. * \param expr The expression to check @@ -120,6 +120,42 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp return true; } +/*! + * \brief Implements an order on pairs (expression,frequency). First attempts to compare them + using the size of the expression. If it is the same, decides something else still + deterministic. + * \param a The first pair + * \param b The second pair + * \return A boolean telling if the first pair `a` comes before the second pair `b` + * \note We need this order to be deterministic in order to have a fully deterministic pass, + * as we will deal with elements that are coming from a hashtable, but the order in which + * they appeared in the hashtable was based on some runtime addresses, so it can potentially + * change with every execution. 
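+ *       For instance, the pair ((x*y) + z, 2) is ordered before (x + 1, 5): the first
+ *       expression is syntactically larger, and the frequency component of the pairs is
+ *       never consulted by this ordering.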
+ */ +bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair a, + std::pair b) { + size_t a_size = CalculateExprComplexity(a.first); + size_t b_size = CalculateExprComplexity(b.first); + + // Criteria 1 - Size of the expression comes first + // `a` comes before `b` if the size of `a` is bigger + if (a_size > b_size) { + return true; + } + // `a` does NOT come before `b` if the size of `b` is bigger + if (b_size > a_size) { + return false; + } + + // Criteria 2 - If they had the same size, use the lexicographic order as a last resort + // as we need a deterministic order + std::stringstream a_stream; + std::stringstream b_stream; + a_stream << a.first; + b_stream << b.first; + return (a_stream.str().compare(b_stream.str()) < 0); +} + /*! * \brief Generates a new fresh variable, whose name will be cse_var_i. * \param type_annotation The type of the new variable to generate @@ -166,10 +202,12 @@ int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; } of the function being analyzed * \return A new statement where CSE has been performed */ -Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) { +Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init, + bool identify_equiv_terms) { // As this function is being called for each PrimFunc definition, we create a new instance // for the one we are having now. - CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init); + CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init, + identify_equiv_terms); return common_subexpression_eliminator.VisitStmt(stmt); } @@ -179,8 +217,9 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& formal parameters of the function that will be analyzed */ CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt, - const Context& context_init) - : initial_body_(stmt), context_(context_init) {} + const Context& context_init, + bool identify_equiv_terms) + : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {} /*! * \brief The method which overrides the generic dispatcher of StmtExprMutator. @@ -200,28 +239,28 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { // Transform the hashtable of *syntactic* eligible computations into a vector of pairs // containing *semantic* entities, i.e. where equivalent computations are merged. 
std::vector> semantic_comp_done_by_expr = - SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr); + SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), - [](std::pair a, std::pair b) { - return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first)); - }); + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { std::pair& computation_and_nb = semantic_comp_done_by_expr[i]; + bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this" + // The predicate later used (when doing replacements) to select expressions that are // equivalent to the current computation (`computation_and_nb.first`) std::function predicate_selector = - [computation_and_nb](const PrimExpr& current_expr) { + [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) { // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check // that `current_expr` is an eligible computation even if we know that // `computation_and_nb.first` is eligible by construction, in case that one day the // equivalence relation would not preserve the eligibility any more (even though that // would probably be a very weird equivalence). - return (EquivalentTerms(current_expr, computation_and_nb.first) && + return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) && IsEligibleComputation(current_expr)); }; @@ -229,10 +268,11 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { // equivalent to `computation_and_nb.first` auto it_on_var = std::find_if( context_.begin(), context_.end(), - [computation_and_nb](const std::pair& var_and_value) { + [computation_and_nb, ident_equiv_terms](const std::pair& var_and_value) { // Note : safe to call value() as we check has_value() just before return (var_and_value.second.has_value() && - EquivalentTerms(var_and_value.second.value(), computation_and_nb.first)); + EquivalentTerms(var_and_value.second.value(), computation_and_nb.first, + ident_equiv_terms)); }); // Case where we have a perfectly equivalent computation already available in a variable @@ -298,7 +338,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by // decreasing size/complexity), and it will only insert at locations > i as the // direct subexprs are necessarily smaller than the current computation. - InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs); + InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, + identify_equiv_terms_); } } // Note : we do not remove the current element, as we never look back in the local vector @@ -378,28 +419,28 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { // Transform the hashtable of *syntactic* eligible computations into a vector of pairs // containing *semantic* entities, i.e. where equivalent computations are merged. 
std::vector> semantic_comp_done_by_stmt = - SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt); + SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), - [](std::pair a, std::pair b) { - return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first)); - }); + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) { std::pair& computation_and_nb = semantic_comp_done_by_stmt[i]; + bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this" + // The predicate later used (when doing replacements) to select expressions that are // equivalent to the current computation (`computation_and_nb.first`) std::function predicate_selector = - [computation_and_nb](const PrimExpr& current_expr) { + [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) { // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check // that `current_expr` is an eligible computation even if we know that // `computation_and_nb.first` is eligible by construction, in case that one day the // equivalence relation would not preserve the eligibility any more (even though that // would probably be a very weird equivalence). - return (EquivalentTerms(current_expr, computation_and_nb.first) && + return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) && IsEligibleComputation(current_expr)); }; @@ -407,10 +448,11 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { // equivalent to `computation_and_nb.first` auto it_on_var = std::find_if( context_.begin(), context_.end(), - [computation_and_nb](const std::pair& var_and_value) { + [computation_and_nb, ident_equiv_terms](const std::pair& var_and_value) { // Note : safe to call value() as we check has_value() just before return (var_and_value.second.has_value() && - EquivalentTerms(var_and_value.second.value(), computation_and_nb.first)); + EquivalentTerms(var_and_value.second.value(), computation_and_nb.first, + ident_equiv_terms)); }); // Case where we have a perfectly equivalent computation already available in a variable @@ -477,7 +519,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by // decreasing size/complexity), and it will only insert at locations > i as the // direct subexprs are necessarily smaller than the current computation. - InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs); + InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs, + identify_equiv_terms_); } } // Note : we do not remove the current element, as we never look back in the local vector @@ -587,8 +630,8 @@ namespace transform { * \brief The function which returns the pass for the Common Subexpression Elimination. * \return The pass for performing CSE. 
*/ -Pass CommonSubexprElimTIR(bool enable_cse_tir) { - auto pass_func = [enable_cse_tir](PrimFunc f, IRModule m, PassContext ctx) { +Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { + auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) { if (enable_cse_tir) { auto* n = f.CopyOnWrite(); Context context_init; @@ -603,7 +646,8 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir) { // Do the Common Subexpression Elimination on the body of the function, with the initial // context that we have prepared - n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init); + n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init, + identify_equiv_terms); } return f; diff --git a/src/tir/transforms/common_subexpr_elim.h b/src/tir/transforms/common_subexpr_elim.h index 484d93c769822..5c14caf1a6e36 100644 --- a/src/tir/transforms/common_subexpr_elim.h +++ b/src/tir/transforms/common_subexpr_elim.h @@ -55,7 +55,7 @@ using Context = std::vector>; class CommonSubexpressionEliminator : public StmtExprMutator { public: // Toplevel (static) function - static Stmt PerformCSE(const Stmt& stmt, const Context& context_init); + static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms); PrimExpr VisitExpr(const PrimExpr& expr) override; Stmt VisitStmt(const Stmt& stmt) override; @@ -64,7 +64,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator { protected: // Constructor - CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init); + CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init, + bool identify_equiv_terms); PrimExpr VisitExpr_(const LetNode* op) override; @@ -77,9 +78,12 @@ class CommonSubexpressionEliminator : public StmtExprMutator { int num_last_try_ = 0; // Number of the last variable tried int nb_var_ = 0; // Number of variables introduced by the CSE pass + bool identify_equiv_terms_ = false; + static bool ForbiddenComputation(const PrimExpr& expr); static bool IsEligibleComputation(const PrimExpr& expr); static bool CanContainEligibleComputations(const PrimExpr& expr); + static bool OrderOnExprAndFrequency(std::pair a, std::pair b); Var GenerateNewVar(DataType type_annotation); }; diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index d39d211ba1824..b5b1bfccdf4ac 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -25,7 +25,8 @@ #include "common_subexpr_elim_tools.h" -#include // For the class Pass and the class PassContext +#include // For the arith::Analyzer::Simplify() method simplifying terms +#include // For the class Pass and the class PassContext #include #include // For the ExprDeepEqual analysis #include @@ -720,14 +721,42 @@ bool EqualTerms(const PrimExpr& a, const PrimExpr& b) { return deep_equal_(a, b); } +/*! + * \brief Normalization function of a term, use to decide the equivalence relation of interest + * \param expr The expression to normalize + * \param do_normalization Whether we want the function to actually do normalization + * \note This function can be customized + */ +PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization) { + if (do_normalization) { + // Customize here! 
+ // We could decide to normalize terms in a way that identifies them modulo commutativity + // (like x+y and y+x), or modulo associativity (like (x+y)+z and x+(y+z)), etc. + // For that, a normalization procedure (or an incomplete "pseudo-normalization" like + // arith::Analyzer::Simplify) will be used. + + // One possible customization: + // Here is just an attempt to do more commonings by using the pseudo-normalization function + // offered by arith::Analyzer::Simplify(). "pseudo" because while it is correct (i.e. + // the simplification is indeed equivalent to the original term), it is incomplete (i.e. + // the returned term is not guaranteed to be a normal form). + arith::Analyzer analyzer; + return analyzer.Simplify(expr); + } else { + // If `do_normalization` is false, the equivalence relation just checks the syntactic equality, + // so the normalization is just the identity function. + return expr; + } +} + /*! * \brief Decides if two terms are equivalent semantically */ -bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { - // For now, we just check the syntactic equality, but that could later become a semantic test, - // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo - // associativity (like (x+y)+z and x+(y+z)), etc. - return EqualTerms(a, b); +bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms) { + // We restrict the equivalence to be decidable by a normalization procedure that is used to + // normalize both sides, and to then compare the normal forms with the strict syntactical + // equality + return EqualTerms(NormalizeTerm(a, identify_equiv_terms), NormalizeTerm(b, identify_equiv_terms)); } /*! @@ -739,21 +768,52 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { \note This function is needed because the advantage of the hashtable was the constant lookup. But in order to have this constant lookup, we could not collapse semantically equivalent computations. + Attention, the pairs returned are deterministic and will always be the same (as the same + canonical representant will always be chosen for a given class of equivalence), but the + order in which these pairs appear in the result is not deterministic, as it is based on + the order in which we found items in the "normalized hashtable" `norm_table`). The caller + is expected to sort the result anyway. */ std::vector> SyntacticToSemanticComputations( - const ComputationTable& table) { + const ComputationTable& table, bool identify_equiv_terms) { std::vector> result; - // table.size() is an upper-bound of the number of elements in the resulting vector, - // as we might merge semantically equivalent computations. - // We do this reservation even if it might reserve slightly more space than is needed in the end - result.reserve(table.size()); + // If we do NOT identify equivalent terms, then we simply need to transform the input hashtable + // into a vector, without doing anything else. 
+ if (!identify_equiv_terms) { + // The result will contain exactly as many elements as the input `table` has + result.reserve(table.size()); + for (const auto& elem : table) { + result.push_back(elem); + } - // Traverse through map in a sorted order on keys to maintain deterministic behavior - // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering - std::vector> sorted_map_items(table.begin(), table.end()); + return result; + } - sort(sorted_map_items.begin(), sorted_map_items.end(), + // Otherwise, in order to identify equivalent terms, we will go through a table `norm_table` + // where normal forms are the keys., and use it to efficiently merge equivalent terms. + + // In order to produce the result (a vector of semantical entities), the input table will be + // normalized. This normalized table will keep the count for each set of equivalent terms + // (i.e. each equivalence class), together with a term that did appear in this equivalence class + // (in practice, the first term of the equivalence class that was encoutered). + std::unordered_map, StructuralHash, ExprDeepEqual> + norm_table; + + // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for + // enough space to store the amount of elements that the input table has, as it's clearly an + // upper bound (in the worst case, each element is its own representant, and there is as many + // equivalence classes as there are elements) + norm_table.reserve(table.size()); + + // Transform the input hashtable to a vector and sort it according to some order, as we will be + // iterating through its items soon, and the order of appearance will be used to determine the + // individual representant for each class of equivalence, which we want to be deterministic + // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x). + std::vector> sorted_items_of_table(table.begin(), table.end()); + + // We do the ordering by comparing the string repr of each expr to get a determinstic ordering + sort(sorted_items_of_table.begin(), sorted_items_of_table.end(), [](std::pair a, std::pair b) { std::stringstream a_stream; std::stringstream b_stream; @@ -762,21 +822,40 @@ std::vector> SyntacticToSemanticComputations( return a_stream.str().compare(b_stream.str()) < 0; }); - // For each element in the hashtable - for (auto elem : sorted_map_items) { - // We try to see if a semantically equivalent term is already in the resulting vector - auto it_found = std::find_if(result.begin(), result.end(), - [elem](std::pair already_seen) { - return EquivalentTerms(already_seen.first, elem.first); - }); - // And if so, we increase (by `elem.second`) its count - if (it_found != result.end()) { - it_found->second += elem.second; + for (const auto& elem : sorted_items_of_table) { + PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms); + // If the normalized term is not already a key in the normalized table + auto it_found = norm_table.find(norm_elem); + if (it_found == norm_table.end()) { + // Then we add the mapping `norm_elem` -> (`elem`.first, `elem`.second) to the norm table + // (i.e. 
`norm_elem` has been seen `elem`.second many times so far, and the chosen element + // to represent the equivalence class will be `elem`.first as it's the first element of the + // class that we see) + norm_table[norm_elem] = elem; } else { - // If we could not find a semantically equivalent term in the resulting vector, we add it - result.push_back(elem); + // Otherwise, it's not the first time we see a term in this equivalence class, so we just + // increase the count of this equivalence class as we now have `elem`.second additional items + // coming to the equivalence class. + it_found->second.second += elem.second; } } + + // norm_table.size() is the number of equivalence class that we have built, so it's exactly the + // number of items that we will return in the vector of semantical entities + result.reserve(norm_table.size()); + + // Transform the intermediate hashtable `norm_table` into a vector, forgetting the keys, + // (which are the normal forms), as they won't be used as the canonical representants (which are + // instead the first element of each class that is effectively seen) + // Careful : the pairs will never change (the canonical represantants chosen will always be the + // same), but the order in which the pairs are produced can vary as we are iterating through the + // hashtable `norm_table`. It is not an issue as the called will be sorting the result anyway. + std::unordered_map, StructuralHash, + ExprDeepEqual>::const_iterator it_norm_table; + for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) { + result.push_back(it_norm_table->second); + } + return result; } @@ -822,17 +901,19 @@ void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, - const std::vector& vec_to_add) { + const std::vector& vec_to_add, + bool identify_equiv_terms) { if (sorted_vec == nullptr) { return; } for (auto elem_to_add : vec_to_add) { // See if the current element to add (or an equivalent one) is already present // in the sorted vector - auto it_found = std::find_if(sorted_vec->begin(), sorted_vec->end(), - [elem_to_add](std::pair elem) { - return EquivalentTerms(elem.first, elem_to_add); - }); + auto it_found = + std::find_if(sorted_vec->begin(), sorted_vec->end(), + [elem_to_add, identify_equiv_terms](std::pair elem) { + return EquivalentTerms(elem.first, elem_to_add, identify_equiv_terms); + }); // If we found `elem_to_add` (or an equivalent expression) already in sorted_vec if (it_found != sorted_vec->end()) { diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index a590cde69fafc..fcd29fddc0a17 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -180,9 +180,12 @@ void PrintComputationTable(const ComputationTable& table); using MaybeValue = dmlc::optional; bool EqualTerms(const PrimExpr& a, const PrimExpr& b); -bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b); +// Used for deciding the (decidable) equivalence relation +PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization); +// The equivalence relation, which is the syntactical equality when `identify_equiv_terms` is false +bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms); std::vector> SyntacticToSemanticComputations( - const ComputationTable& table); + const ComputationTable& table, bool identify_equiv_terms); bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen); // 
Polymorphic (functional) map on a vector, which builds a news vector with the same number of @@ -209,7 +212,8 @@ template std::vector VectorMap(const std::vector void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, const std::pair& pair); void InsertVectorToSortedSemanticComputations(std::vector>* sorted_vec, - const std::vector& vec_to_add); + const std::vector& vec_to_add, + bool identify_equiv_terms); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index c12e27a46e3f2..a546c16a648ec 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -17,12 +17,16 @@ import hashlib import tvm -from tvm import te +from tvm import auto_scheduler, te, topi from tvm.ir.base import save_json from tvm.ir.module import IRModule +from tvm.script import tir as T - -# A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels +# ----------------------------------------------------- +# Basic test for the expected Behavior of the CSE pass +# ----------------------------------------------------- +# A test program which gives the opportunity for the CSE pass to introduce two new variables, +# at two different levels def test_cse(): z1 = te.var("z1") z2 = te.var("z2") @@ -70,9 +74,9 @@ def test_cse(): ), ), ) - # This test program gives the opportunity to introduce two new variables, at two different levels - # and to perform replacements in the value of "a" and "b", using these new variables - # We will check all of that underneath and more, making also sure that nothing else has been changed + # This test program gives the opportunity to introduce two new variables, at two different + # levels and to perform replacements in the value of "a" and "b", using these new variables. 
+ # We will check all of that underneath and more, making also sure that nothing else has changed mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) @@ -138,52 +142,14 @@ def test_cse(): assert isinstance(body.body, tvm.tir.BufferStore) -def test_deterministic_cse(): - import random - - """Test deterministic allocation of CSE vars - - We expect something like - - result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) - --> - cse_var_3 = (x + 1) - cse_var_2 = (x + 2) - cse_var_1 = (x + 3) - result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 - """ - NUM_TERMS = 10 - REPEATS = 10 - - x = te.var("x") - result = te.var("result") - - offsets = sorted([i + 1 for i in range(NUM_TERMS)]) - inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)] - inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)] - - expression = x - for add in inc1 + inc2: - expression = expression + add - let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt)) - - initial_hash = None - for _ in range(REPEATS): - body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"] - - # Hash and ensure serialize json is the same every time - json_val = save_json(body) - json_hash = hashlib.sha256(json_val.encode()).hexdigest() - - if initial_hash is None: - initial_hash = json_hash - assert json_hash == initial_hash - - -# First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches. -# In this case, the CSE pass should introduce the redundant computation at the top if the Then branch, not before the whole If -# (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed). +# ----------------------------------------------------- +# Tests related to If nodes +# ----------------------------------------------------- +# First specific test for if nodes : Some duplicated computations appear only in one branch (here +# the Then branch), not in both branches. +# In this case, the CSE pass should introduce the redundant computation at the top of the Then +# branch, not before the whole If (otherwise that would lead to some computations being computed +# for nothing when it is the Else branch that is executed). def test_cse_ifNode_1(): b = te.var("b") i1 = te.var("i1") @@ -237,9 +203,9 @@ def test_cse_ifNode_1(): assert tvm.ir.structural_equal(body.value, y + z) -# Second test for if nodes : Some duplicated computations appear in both the Then and the Else branch. -# In this case, the CSE pass should introduce the redundant computation before the whole If node, because -# regardless of the execution path, it is going to be computed. +# Second test for if nodes : Some duplicated computations appear in both the Then and Else branch. +# In this case, the CSE pass should introduce the redundant computation before the whole If node, +# because regardless of the execution path, it is going to be computed. 
def test_cse_ifNode_2(): b = te.var("b") i1 = te.var("i1") @@ -265,7 +231,7 @@ def test_cse_ifNode_2(): b, tvm.tir.SeqStmt( [ - tvm.tir.BufferStore(buffer, y + z, [i1]), # (y+z) is present in the Then branch + tvm.tir.BufferStore(buffer, y + z, [i1]), # (y+z) is present in Then branch tvm.tir.BufferStore(buffer, y, [i2]), ] ), @@ -288,9 +254,11 @@ def test_cse_ifNode_2(): assert tvm.ir.structural_equal(body.value, y + z) +# ------------------------------------------------------------------------------------------------- # Test commoning in cascade : after having introduced a big exp ((x+y)+z) into a new variable, # it will become possible to do another commoning for (x+y) which appears both in the new variable # and in the rest of the program. +# ------------------------------------------------------------------------------------------------- def test_cse_cascade(): i1 = te.var("i1") i2 = te.var("i2") @@ -353,8 +321,188 @@ def test_cse_cascade(): assert tvm.ir.structural_equal(store3.value, cse_var_2) +# ----------------------------------------------------------------------------------------- +# A test which ensures that we don't perform normalizations outside of introduced variables +# ----------------------------------------------------------------------------------------- +def test_no_normalization_without_commoning(): + x = te.var("x") + y = te.var("y") + z = te.var("z") + a = te.var("a") + # Test prog : + # let a = x + (y + z) in a + body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a)) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x, y, z], body)) + body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod) + + tvm.transform.PrintIR()(body) + + body = body["main"].body # Gets the body of the main, i.e. the full statement + + assert body.var.name == "a" + assert tvm.ir.structural_equal(body.value, x + (y + z)) + + +# ------------------------------------------------- +# Part for testing the commoning with equivalences +# ------------------------------------------------- +@T.prim_func +def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = x * (y + z) + B[i2] = x * y + x * z + + +@T.prim_func +def func_distributivity_expected( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, x * y + x * z): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + + +@T.prim_func +def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = (x + y) + z + B[i2] = x + (y + z) + + +@T.prim_func +def func_associativity_expected( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, (x + y) + z): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod) + tvm.transform.PrintIR()(body) + tvm.ir.assert_structural_equal(body["main"], transformed) + + +def test_semantic_equiv_distributivity(): + _check(func_distributivity, func_distributivity_expected) + + +def test_semantic_equiv_associativity(): + _check(func_associativity, func_associativity_expected) + + +# ----------------------------------------------------- +# Tests that verify the determinism of the pass 
+# ----------------------------------------------------- +def test_deterministic_cse(): + import random + + """Test deterministic allocation of CSE vars + + We expect something like + + result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) + --> + cse_var_3 = (x + 1) + cse_var_2 = (x + 2) + cse_var_1 = (x + 3) + result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 + """ + NUM_TERMS = 10 + REPEATS = 10 + + x = te.var("x") + result = te.var("result") + + offsets = sorted([i + 1 for i in range(NUM_TERMS)]) + inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)] + inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)] + + expression = x + for add in inc1 + inc2: + expression = expression + add + let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt)) + + initial_hash = None + for _ in range(REPEATS): + body = tvm.tir.transform.CommonSubexprElimTIR()(mod) + + body = body["main"] + + # Hash and ensure serialize json is the same every time + json_val = save_json(body) + json_hash = hashlib.sha256(json_val.encode()).hexdigest() + + if initial_hash is None: + initial_hash = json_hash + assert json_hash == initial_hash + + +# Needed for the second test on determinism +LOG_LINE = '{"i": [["[\\"conv2d_layer\\", 1, 7, 7, 512, 512, 3, 3, [1, 1], [1, 1]]", \ + "llvm -keys=cpu -link-params=0 -mcpu=broadwell -num-cores=2", \ + [8, 64, 64, 0, 0, 0, 0, 0], "", 1, []], [[], [["CI", 5], \ + ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 512, [1, 32, 16], 1], \ + ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 7, [1, 1, 1], 1], \ + ["SP", 3, 16, 512, [1], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 3, [3], 1], \ + ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, \ + 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP", 6, 6, 3, 2], \ + ["FSP", 6, 9, 4, 2], ["RE", 6, [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]], \ + ["CA", 3, 6, 7], ["CA", 1, 6, 5], ["FU", 6, [0, 1, 2, 3, 4, 5]], ["AN", 6, 0, 3], \ + ["PR", 3, 0, "auto_unroll_max_step$512"], ["AN", 1, 3, 2], ["AN", 3, 21, 2], \ + ["AN", 6, 6, 2]]]], "r": [[0.0331129], 0, 0.900362, 1647464342], "v": "v0.6"}\n' + +# The workload associated with the log +@auto_scheduler.register_workload +def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding): + data = te.placeholder((N, CI, H, W), name="data") + kernel = te.placeholder((CO, CI, KH, KW), name="kernel") + bias = te.placeholder((1, CO, 1, 1), name="bias") + conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32") + out = topi.nn.relu(conv + bias) + return [data, kernel, bias, out] + + +def test_deterministic_cse_2(): + inp, inr = auto_scheduler.measure_record.load_record_from_string(LOG_LINE) + inp = auto_scheduler.measure.recover_measure_input(inp, rebuild_state=True) + + initial_hash = None + + for _ in range(10): + sch, args = inp.task.compute_dag.apply_steps_from_state(inp.state) + ir_module = tvm.lower(sch, args) + primfunc = ir_module["main"] + json_str = save_json(primfunc) + new_hash = hashlib.sha256(json_str.encode("utf-8")).hexdigest() + # Make sure that all the hashes are going to be the same + if initial_hash is None: + initial_hash = new_hash + assert new_hash == initial_hash + + if __name__ == "__main__": + # Basic test: test_cse() + # Tests related to If nodes: test_cse_ifNode_1() test_cse_ifNode_2() + # Test performing a commoning on a commoning: test_cse_cascade() + # Test that verifies that the input program itself is not 
being normalized by the pass: + test_no_normalization_without_commoning() + # Tests that turn on the equivalence of terms and verify the commoning with equivalences: + test_semantic_equiv_distributivity() + test_semantic_equiv_associativity() + # Tests that verify the determinism of the pass: + test_deterministic_cse() + test_deterministic_cse_2() From ebc9b6d41cbb6720654dd1fd54488a88b4a8898d Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Thu, 9 Jun 2022 09:41:02 -0700 Subject: [PATCH 078/181] [ci] Add guards to pytest_wrapper (#11553) This should fix #11544 and adds some more logging in case the issue persists. Unfortunately it is difficult to test for real since the case data in that PR is thrown away after Jenkins is done (Jenkins does store test data but it marshals JUnits into its own format) Co-authored-by: driazati --- Jenkinsfile | 338 +++++++++++++++++++++++++++++++- jenkins/macros.j2 | 12 ++ tests/scripts/git_utils.py | 5 +- tests/scripts/pytest_wrapper.py | 9 +- 4 files changed, 358 insertions(+), 6 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 334448a7ae24b..0205a1e7364fe 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,7 +45,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-06-01T16:34:53.941462 +// Generated at 2022-06-02T14:03:43.284817 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> @@ -847,6 +847,14 @@ def shard_run_unittest_GPU_1_of_3() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -899,6 +907,14 @@ def shard_run_unittest_GPU_2_of_3() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -947,6 +963,14 @@ def shard_run_unittest_GPU_3_of_3() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -994,6 +1018,14 @@ def shard_run_integration_CPU_1_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1040,6 +1072,14 @@ def shard_run_integration_CPU_2_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1086,6 +1126,14 @@ def shard_run_integration_CPU_3_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1132,6 +1180,14 @@ def shard_run_integration_CPU_4_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload 
JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1178,6 +1234,14 @@ def shard_run_integration_CPU_5_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1224,6 +1288,14 @@ def shard_run_integration_CPU_6_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1271,6 +1343,14 @@ def shard_run_python_i386_1_of_5() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1317,6 +1397,14 @@ def shard_run_python_i386_2_of_5() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1362,6 +1450,14 @@ def shard_run_python_i386_3_of_5() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1407,6 +1503,14 @@ def shard_run_python_i386_4_of_5() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1452,6 +1556,14 @@ def shard_run_python_i386_5_of_5() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1498,6 +1610,14 @@ def shard_run_test_Hexagon_1_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1542,6 +1662,14 @@ def shard_run_test_Hexagon_2_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1586,6 +1714,14 @@ def shard_run_test_Hexagon_3_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1630,6 +1766,14 @@ def shard_run_test_Hexagon_4_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1674,6 +1818,14 @@ def shard_run_test_Hexagon_5_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1718,6 +1870,14 @@ def shard_run_test_Hexagon_6_of_7() { }) } } 
finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1762,6 +1922,14 @@ def shard_run_test_Hexagon_7_of_7() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1808,6 +1976,14 @@ def shard_run_integration_aarch64_1_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1853,6 +2029,14 @@ def shard_run_integration_aarch64_2_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1898,6 +2082,14 @@ def shard_run_integration_aarch64_3_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1943,6 +2135,14 @@ def shard_run_integration_aarch64_4_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -1988,6 +2188,14 @@ def shard_run_topi_GPU_1_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2032,6 +2240,14 @@ def shard_run_topi_GPU_2_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2076,6 +2292,14 @@ def shard_run_topi_GPU_3_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2120,6 +2344,14 @@ def shard_run_topi_GPU_4_of_4() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2165,6 +2397,14 @@ def shard_run_frontend_GPU_1_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2209,6 +2449,14 @@ def shard_run_frontend_GPU_2_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2253,6 +2501,14 @@ def shard_run_frontend_GPU_3_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results 
--recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2297,6 +2553,14 @@ def shard_run_frontend_GPU_4_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2341,6 +2605,14 @@ def shard_run_frontend_GPU_5_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2385,6 +2657,14 @@ def shard_run_frontend_GPU_6_of_6() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2435,6 +2715,14 @@ def shard_run_topi_aarch64_1_of_2() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2483,6 +2771,14 @@ def shard_run_topi_aarch64_2_of_2() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2528,6 +2824,14 @@ def shard_run_frontend_aarch64_1_of_2() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2572,6 +2876,14 @@ def shard_run_frontend_aarch64_2_of_2() { }) } } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2742,6 +3054,14 @@ stage('Test') { ) }) } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2787,6 +3107,14 @@ stage('Test') { ) }) } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } @@ -2827,6 +3155,14 @@ stage('Test') { ) }) } finally { + sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) + junit 'build/pytest-results/*.xml' } } diff --git a/jenkins/macros.j2 b/jenkins/macros.j2 index 5a641b73fea84..5d996ce19a559 100644 --- a/jenkins/macros.j2 +++ b/jenkins/macros.j2 @@ -19,6 +19,16 @@ "workspace/exec_${env.EXECUTOR_NUMBER}/{{ folder }}" {%- endmacro -%} +{% macro junit_to_s3() %} +sh( + script: """ + set -eux + aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results --recursive + """, + label: 'Upload JUnits to S3', + ) +{% endmacro %} + {% macro sharded_test_step(name, num_shards, node, ws, docker_image, platform, test_method_names) %} {% for shard_index in range(1, num_shards + 1) %} @@ -39,6 +49,7 @@ def {{ method_name }}() { }) } } finally { + {{ 
junit_to_s3() }} junit 'build/pytest-results/*.xml' } } @@ -86,6 +97,7 @@ def {{ method_name }}() { {{ caller() | indent(width=12) | trim }} }) } finally { + {{ junit_to_s3() | indent(width=4) }} junit 'build/pytest-results/*.xml' } } diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py index 0e2e85e552431..267756d859050 100644 --- a/tests/scripts/git_utils.py +++ b/tests/scripts/git_utils.py @@ -36,7 +36,7 @@ def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = req = request.Request(url, headers=headers, method="POST") if auth is not None: auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}".encode()) - req.add_header("Authorization", f"Basic {auth_str}") + req.add_header("Authorization", f"Basic {auth_str.decode()}") if body is None: body = "" @@ -47,8 +47,7 @@ def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = req.add_header("Content-Length", len(data)) with request.urlopen(req, data) as response: - response = json.loads(response.read()) - return response + return response.read() class GitHubRepo: diff --git a/tests/scripts/pytest_wrapper.py b/tests/scripts/pytest_wrapper.py index a7b6f0dfa766d..4c4410bedc9c6 100755 --- a/tests/scripts/pytest_wrapper.py +++ b/tests/scripts/pytest_wrapper.py @@ -18,6 +18,7 @@ import argparse import textwrap import junitparser +import traceback from pathlib import Path from typing import List, Optional import os @@ -51,6 +52,10 @@ def failed_test_ids() -> List[str]: for suite in xml: # handle suites for case in suite: + if case.result is None: + logging.warn(f"Incorrectly formatted JUnit found, result was None on {case}") + continue + if len(case.result) > 0 and isinstance(case.result[0], FAILURE_TYPES): node_id = classname_to_file(case.classname) + "::" + case.name failed_node_ids.append(node_id) @@ -112,7 +117,7 @@ def show_failure_help(failed_suites: List[str]) -> None: "If there is no test listed below, the failure likely came from a segmentation " "fault which you can find in the logs above.\n" ) - if len(failed_suites) > 0: + if failed_suites is not None and len(failed_suites) > 0: print("\n".join([f" - {suite}" for suite in failed_suites])) print("") @@ -131,4 +136,4 @@ def show_failure_help(failed_suites: List[str]) -> None: except Exception as e: # This script shouldn't ever introduce failures since it's just there to # add extra information, so ignore any errors - logging.error(str(e)) + logging.exception(e) From 87502ddd9002cdfe1035a2bc1c7063e33098ced1 Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Thu, 9 Jun 2022 10:14:46 -0700 Subject: [PATCH 079/181] [PASS] Refactor a couple of TIR passes - BindTarget, AnnotateEntryFunc, Filter, LowerInitBlock (#11628) This PR fixes a few inconsistent pass registration and add testcases for them. - `LowerInitBlock` had mismatch between its pass name and ffi key. - `BindTarget`, `AnnotateEntryFunc`, `Filter` were not following the name convention of tir passes and they were not registered in FFI registry. 
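Since all three passes are now registered in the FFI registry, they can be driven directly from Python. The snippet below is only a rough sketch mirroring the new unit tests in this change (the module, target and predicate are made up for illustration), not an exhaustive reference:

    import tvm
    from tvm.script import tir as T

    @tvm.script.ir_module
    class Example:  # hypothetical single-PrimFunc module, in the style of the added tests
        @T.prim_func
        def func(A: T.Buffer[(16,), "float32"]):
            for i in T.serial(16):
                A[i] = 0.0

    target = tvm.target.Target("llvm")  # any target; "llvm" is an arbitrary choice here
    # Attach the target to every PrimFunc in the module.
    mod = tvm.tir.transform.BindTarget(target)(Example)
    # Mark the PrimFunc as the entry function; the pass expects exactly one PrimFunc in the module.
    mod = tvm.tir.transform.AnnotateEntryFunc()(mod)
    # Keep only the PrimFuncs for which the predicate returns True (everything, in this sketch).
    mod = tvm.tir.transform.Filter(lambda f: True)(mod)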
--- include/tvm/tir/transform.h | 19 +++ python/tvm/tir/transform/transform.py | 61 ++++++--- src/driver/driver_api.cc | 45 ++----- src/tir/transforms/lower_init_block.cc | 2 +- src/tir/transforms/primfunc_utils.cc | 63 +++++++++ .../convert_pool_allocations_to_offsets.cc | 2 +- .../unittest/test_tir_transform_helpers.py | 123 ++++++++++++++++++ 7 files changed, 258 insertions(+), 57 deletions(-) create mode 100644 src/tir/transforms/primfunc_utils.cc create mode 100644 tests/python/unittest/test_tir_transform_helpers.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 4612d5ad3feac..6393eeb9430b9 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -25,6 +25,7 @@ #define TVM_TIR_TRANSFORM_H_ #include +#include #include #include @@ -625,6 +626,24 @@ TVM_DLL Pass ExtractPrimFuncConstants(); */ TVM_DLL Pass RenormalizeSplitPattern(); +/*! + * \brief Annotate a PrimFunc with a given target. + * \return The pass. + */ +TVM_DLL Pass BindTarget(Target target); + +/*! + * \brief Set a PrimFunc as the entry point if it is only function in IRModule. + * \return The pass. + */ +TVM_DLL Pass AnnotateEntryFunc(); + +/*! + * \brief Filter PrimFuncs with a given condition. + * \return The pass. + */ +TVM_DLL Pass Filter(runtime::TypedPackedFunc fcond); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1bed29c560fc9..e0a7501ef92af 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -16,7 +16,8 @@ # under the License. """Wrapping existing transformations.""" # pylint: disable=invalid-name -from typing import Optional +from typing import Optional, Callable + from . import _ffi_api from . import function_pass as _fpass @@ -43,26 +44,6 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore -def Filter(fcond): - """Filter functions by the calling convention attribute. - - Parameters - ---------- - fcond : tvm.tir.PrimFunc -> bool - The condition of the filtering. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - # pylint: disable=unused-argument - def _transform(func, mod, ctx): - return func if fcond(func) else None - - return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") # type: ignore - - def InjectPrefetch(): """Inject prefetch instructions into stmt. @@ -806,3 +787,41 @@ def RenormalizeSplitPattern(): The result pass """ return _ffi_api.RenormalizeSplitPattern() # type: ignore + + +def BindTarget(target): + """Annotate a PrimFunc with a given target. + Parameters + ------- + target : tvm.target.Target + target + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BindTarget(target) # type: ignore + + +def AnnotateEntryFunc(): + """Set a PrimFunc as the entry point if it is only function in IRModule. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateEntryFunc() # type: ignore + + +def Filter(fcond: Callable): + """Filter out PrimFuncs that does not satisfy the given condition. + `fcond` should be a function that takes a primfunc and returns boolean. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.Filter(fcond) # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7706f229c9ed3..ace31800de27f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -164,32 +164,6 @@ TVM_REGISTER_GLOBAL("driver.get_binds") return out_arr; }); -transform::Pass BindTarget(Target target) { - auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tvm::attr::kTarget, target); - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); -} - -static transform::Pass AnnotateEntryFunc(bool b) { - auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); -} - -template -transform::Pass Filter(FCond fcond) { - auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - if (fcond(f)) { - return f; - } else { - return tir::PrimFunc(nullptr); - } - }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); -} - Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -564,12 +538,12 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; - mixed_pass_list.push_back(BindTarget(target)); + mixed_pass_list.push_back(tir::transform::BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); if (ShouldAnnotateEntryFunc(mixed_mod)) { - mixed_pass_list.push_back(AnnotateEntryFunc(true)); + mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); } bool detect_global_barrier = @@ -606,14 +580,16 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { Array host_pass_list; - host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + + runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; - })); + }; + host_pass_list.push_back(tir::transform::Filter(fcond)); ICHECK(mixed_mod.defined()) << "This module must be defined"; - host_pass_list.push_back(BindTarget(target_host)); + host_pass_list.push_back(tir::transform::BindTarget(target_host)); host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); @@ -631,12 +607,13 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes") transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { Array device_pass_list; - device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; - })); + }; + device_pass_list.push_back(tir::transform::Filter(fcond)); - device_pass_list.push_back(BindTarget(target)); + device_pass_list.push_back(tir::transform::BindTarget(target)); device_pass_list.push_back(tir::transform::LowerWarpMemory()); device_pass_list.push_back(tir::transform::Simplify()); diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index d8621ac3b3e6d..17b4e3fb22e62 100644 --- a/src/tir/transforms/lower_init_block.cc +++ 
b/src/tir/transforms/lower_init_block.cc @@ -81,7 +81,7 @@ Pass LowerInitBlock() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { return LowerInitBlock(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerReduction", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc new file mode 100644 index 0000000000000..d2bb259f9921f --- /dev/null +++ b/src/tir/transforms/primfunc_utils.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file primfunc_utils.cc + * \brief Passes that serve as helper functions. + */ + +#include +#include + +namespace tvm { +namespace tir { +namespace transform { +transform::Pass BindTarget(Target target) { + auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tvm::attr::kTarget, target); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {}); +} + +transform::Pass AnnotateEntryFunc() { + auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + ICHECK(m->functions.size() == 1); + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {}); +} + +transform::Pass Filter(runtime::TypedPackedFunc fcond) { + auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + if (fcond(f)) { + return f; + } else { + return tir::PrimFunc(nullptr); + } + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); +TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); +TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index dc71e3d60891c..1161962f12872 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -60,7 +60,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { PoolInfo pool_info = pool_allocation->pool_info; int byte_pool_offset = pool_allocation->byte_offset->value; int required_pool_size_for_allocation = - byte_pool_offset + CalculateExtentsSize(allocate_node.operator->()); + byte_pool_offset + static_cast(CalculateExtentsSize(allocate_node.operator->())); if 
(all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { all_pools_sizes_[pool_info] = required_pool_size_for_allocation; } else { diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py new file mode 100644 index 0000000000000..01496e0e0fc13 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.script import tir as T +import tvm.testing + + +def test_annotate_entry_func_single_primfunc(): + @tvm.script.ir_module + class MockModule: + @T.prim_func + def func1(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + mod = MockModule + assert mod + assert mod["func1"].attrs is None + after = tvm.tir.transform.AnnotateEntryFunc()(mod) + assert ( + after["func1"].attrs + and "tir.is_entry_func" in after["func1"].attrs + and after["func1"].attrs["tir.is_entry_func"] + ) + + +# Test module +@tvm.script.ir_module +class MockModule: + @T.prim_func + def func1(A: T.Buffer[(16,), "float32"]): + for i in T.serial(16): + if i == 5: + if i == 5: + A[i] = 0.0 + + @T.prim_func + def func2(A: T.Buffer[(32,), "float32"]): + for i in T.serial(32): + if i == 15: + if i == 15: + A[i] = 0.0 + + +@pytest.mark.xfail +def test_annotate_entry_func_multiple_primfunc(): + mod = MockModule + assert mod + assert mod["func1"].attrs is None + assert mod["func2"].attrs is None + # This should fail + after = tvm.tir.transform.AnnotateEntryFunc()(mod) + + +def test_bind_target(): + mod = MockModule + assert mod + + target = tvm.target.Target("cuda") + assert mod["func1"].attrs is None + assert mod["func2"].attrs is None + after = tvm.tir.transform.BindTarget(target)(mod) + + assert after["func1"].attrs and "target" in after["func1"].attrs + assert after["func1"].attrs["target"] == target + assert after["func2"].attrs and "target" in after["func2"].attrs + assert after["func2"].attrs["target"] == target + + +def test_filter_primfunc(): + mod = MockModule + assert mod + # Annotate each function for testing + mod["func1"] = mod["func1"].with_attr("temp", "test1") + mod["func2"] = mod["func2"].with_attr("temp", "test2") + + # Test condition that does not filter out anything + def checker_filter_out_none(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("temp" in func.attrs) + + after = tvm.tir.transform.Filter(checker_filter_out_none)(mod) + assert len(after.functions) == 2 + # Filtered functions should satisfy the given condition. 
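The three helper passes exercised by this test file compose like any other module-level pass. A minimal usage sketch (illustrative only: the module handling, the "llvm" target string and the filtering predicate are assumptions, not taken from this patch):

import tvm

def apply_helper_passes(mod):
    # Attach a target to every PrimFunc in the module.
    mod = tvm.tir.transform.BindTarget(tvm.target.Target("llvm"))(mod)
    # Mark the PrimFunc as the entry point; the pass expects the module to
    # contain exactly one function.
    mod = tvm.tir.transform.AnnotateEntryFunc()(mod)
    # Keep only PrimFuncs that now carry a target attribute.
    keep = lambda f: f.attrs is not None and "target" in f.attrs
    mod = tvm.tir.transform.Filter(keep)(mod)
    return mod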
+ assert checker_filter_out_none(after["func1"]) + assert checker_filter_out_none(after["func2"]) + + # Test condition that selectively filters out primfuncs + def checker_filter_out_one(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("temp" in func.attrs) and func.attrs["temp"] == "test1" + + after = tvm.tir.transform.Filter(checker_filter_out_one)(mod) + assert len(after.functions) == 1 + # Filtered functions should satisfy the given condition. + assert checker_filter_out_one(after["func1"]) + + # Test condition that filters out everything + def checker_filter_out_both(func: tvm.tir.PrimFunc): + return (func.attrs is not None) and ("invalid_attr" in func.attrs) + + after = tvm.tir.transform.Filter(checker_filter_out_both)(mod) + assert len(after.functions) == 0 + + +if __name__ == "__main__": + tvm.testing.main() From 7f1b819cdbc70fabaabe9374932e98a3c4bc4660 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 9 Jun 2022 11:17:09 -0600 Subject: [PATCH 080/181] [microTVM] Remove microTVM RVM version suffix (#11629) --- apps/microtvm/reference-vm/base-box-tool.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index a4777c3ff86f4..db89f323328e1 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -479,7 +479,7 @@ def release_command(args): if args.release_full_name: vm_name = args.release_full_name else: - vm_name = f"tlcpack/microtvm-{args.platform}-{args.platform_version}" + vm_name = f"tlcpack/microtvm-{args.platform}" if not args.skip_creating_release_version: subprocess.check_call( @@ -604,14 +604,6 @@ def parse_args(): action="store_true", help="Skip creating the version and just upload for this provider.", ) - parser_release.add_argument( - "--platform-version", - required=False, - help=( - "For Zephyr, the platform version to release, in the form 'x.y'. " - "For Arduino, the version of arduino-cli that's being used, in the form 'x.y.z'." - ), - ) parser_release.add_argument( "--release-full-name", required=False, @@ -619,15 +611,11 @@ def parse_args(): default=None, help=( "If set, it will use this as the full release name and version for the box. " - "If this set, it will ignore `--platform-version` and `--release-version`." + "If this set, it will ignore `--release-version`." 
), ) args = parser.parse_args() - - if args.action == "release" and not args.release_full_name: - parser.error("--platform-version is requireed.") - return args From f528a9a1cd5a0145e07b0bebcc43ab9020767cc9 Mon Sep 17 00:00:00 2001 From: czh978 <41666381+czh978@users.noreply.github.com> Date: Fri, 10 Jun 2022 01:33:44 +0800 Subject: [PATCH 081/181] [Frontend][TFLite] Improve support for half_pixel_centers in resize (#11521) * add resize_nearest_neighbor op test * Improve support for half_pixel_centers in resize --- python/tvm/relay/frontend/tflite.py | 7 ++++++- tests/python/frontend/tflite/test_forward.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 342c4e2ae553a..981074b6adb24 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -695,10 +695,15 @@ def _convert_resize(self, method, op): coord_trans = "align_corners" if align_corners else "asymmetric" coord_trans = "half_pixel" if half_pixel_centers else coord_trans + rounding_method = "" + if method == "nearest_neighbor": + if not align_corners and half_pixel_centers: + rounding_method = "round_prefer_ceil" + if bilinear_method and input_tensor.qnn_params: in_expr = self.dequantize(in_expr, input_tensor) out = _op.image.resize2d( - in_expr, target_size, None, "NHWC", method, coordinate_transformation_mode=coord_trans + in_expr, target_size, None, "NHWC", method, coord_trans, rounding_method ) if bilinear_method and output_tensor.qnn_params: out = self.quantize(out, output_tensor) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 8b0244d75eda8..76b0766dae284 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1693,6 +1693,20 @@ def test_all_resize(): align_corners=False, half_pixel_centers=False, ) + _test_resize( + tf.image.resize_nearest_neighbor, + images_data_float32, + size_data, + align_corners=True, + half_pixel_centers=False, + ) + _test_resize( + tf.image.resize_nearest_neighbor, + images_data_float32, + size_data, + align_corners=False, + half_pixel_centers=True, + ) ####################################################################### From 81b42e67460f11955794f7fc48465b15f16ae57b Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Thu, 9 Jun 2022 19:02:31 +0100 Subject: [PATCH 082/181] Making CMSIS-NN tests pylint compliant (#11625) --- tests/lint/pylint.sh | 2 + tests/python/contrib/test_cmsisnn/__init__.py | 17 ++ .../contrib/test_cmsisnn/test_binary_ops.py | 22 +- .../contrib/test_cmsisnn/test_conv2d.py | 25 +- .../test_cmsisnn/test_extract_constants.py | 217 ++++++++++-------- .../test_cmsisnn/test_fully_connected.py | 28 ++- .../test_cmsisnn/test_generate_constants.py | 19 +- .../test_cmsisnn/test_invalid_graphs.py | 14 +- .../contrib/test_cmsisnn/test_networks.py | 22 +- .../contrib/test_cmsisnn/test_pooling.py | 17 +- .../test_scalar_to_tensor_constant.py | 201 ++++++++-------- .../contrib/test_cmsisnn/test_softmax.py | 11 +- tests/python/contrib/test_cmsisnn/utils.py | 7 +- 13 files changed, 326 insertions(+), 276 deletions(-) create mode 100644 tests/python/contrib/test_cmsisnn/__init__.py diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh index 6c958a9231395..b442c33c0ff67 100755 --- a/tests/lint/pylint.sh +++ b/tests/lint/pylint.sh @@ -20,3 +20,5 @@ set -euxo pipefail python3 -m pylint 
python/tvm --rcfile="$(dirname "$0")"/pylintrc python3 -m pylint vta/python/vta --rcfile="$(dirname "$0")"/pylintrc python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile="$(dirname "$0")"/pylintrc +python3 -m pylint tests/python/contrib/test_cmsisnn --rcfile="$(dirname "$0")"/pylintrc + diff --git a/tests/python/contrib/test_cmsisnn/__init__.py b/tests/python/contrib/test_cmsisnn/__init__.py new file mode 100644 index 0000000000000..f9a622464a479 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for CMSIS-NN""" diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index 49c76870157ea..fec18c197e045 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -18,17 +18,19 @@ """CMSIS-NN integration tests: binary ops""" import itertools -import sys import numpy as np -from enum import Enum import pytest import tvm from tvm import relay from tvm.relay.op.contrib import cmsisnn +from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run +from tvm.micro.testing.aot_test_utils import ( + AOT_USMP_CORSTONE300_RUNNER, +) -from utils import ( +from .utils import ( skip_if_no_reference_system, make_module, make_qnn_relu, @@ -36,11 +38,6 @@ assert_partitioned_function, assert_no_external_function, ) -from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run -from tvm.micro.testing.aot_test_utils import ( - AOT_CORSTONE300_RUNNER, - AOT_USMP_CORSTONE300_RUNNER, -) def generate_tensor_constant(): @@ -104,6 +101,7 @@ def make_model( def test_op_int8( op, relu_type, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point ): + """Tests QNN Conv2D operator for CMSIS-NN""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -147,8 +145,10 @@ def test_op_int8( ) -# At least one of the inputs is a constant, both can't be variables, both can't be scalars def parameterize_for_constant_inputs(test): + """Generates parameters in such a way so that at least one of the inputs is a constant, + both can't be variables, both can't be scalars. 
+ """ op = [relay.qnn.op.mul, relay.qnn.op.add] input_0 = [generate_variable("input_0"), generate_tensor_constant(), generate_scalar_constant()] input_1 = [generate_variable("input_1"), generate_tensor_constant(), generate_scalar_constant()] @@ -178,6 +178,7 @@ def parameterize_for_constant_inputs(test): @tvm.testing.requires_cmsisnn @parameterize_for_constant_inputs def test_constant_input_int8(op, input_0, input_1): + """Tests binary ops where one of the operands is a constant""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -231,9 +232,9 @@ def test_constant_input_int8(op, input_0, input_1): def test_both_scalar_inputs_int8( op, ): + """Tests binary ops where both operands are scalars""" input_scale = 0.256 input_zero_point = 33 - dtype = "int8" model = make_model( op, generate_scalar_constant(), @@ -257,6 +258,7 @@ def test_invalid_parameters( op, input_dtype, ): + """Tests binary ops for non int8 dtypes""" input_scale = 0.256 input_zero_point = 33 model = make_model( diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 90261e540a7d6..462eb88347194 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -26,8 +26,7 @@ from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_models, compile_and_run from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER -from utils import ( - skip_if_no_reference_system, +from .utils import ( make_module, get_range_for_dtype_str, get_same_padding, @@ -76,7 +75,7 @@ def make_model( shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) rng = np.random.default_rng(12321) - w = tvm.nd.array( + weight = tvm.nd.array( rng.integers( np.iinfo(kernel_dtype).min, high=np.iinfo(kernel_dtype).max, @@ -84,7 +83,7 @@ def make_model( dtype=kernel_dtype, ) ) - weight_const = relay.const(w, kernel_dtype) + weight_const = relay.const(weight, kernel_dtype) conv = relay.qnn.op.conv2d( invar, weight_const, @@ -102,8 +101,8 @@ def make_model( padding=p, out_dtype="int32", ) - b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) - bias_const = relay.const(b, "int32") + bias = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(bias, "int32") last_op = relay.nn.bias_add(conv, bias_const, axis=3) if enable_bias else conv requant_input_sc = [sc * input_scale for sc in kernel_scale] last_op = relay.qnn.op.requantize( @@ -115,7 +114,7 @@ def make_model( out_dtype=dtype, ) last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) - params = {"w": w, "b": b} + params = {"w": weight, "b": bias} return last_op, params @@ -134,9 +133,9 @@ def test_conv2d_number_primfunc_args( kernel_scale, out_channels, ): + """Tests number of arguments in Conv2D primfunc""" interface_api = "c" use_unpacked_api = True - test_runner = AOT_USMP_CORSTONE300_RUNNER ifm_shape = (1, 64, 100, 4) kernel_size = (3, 3) @@ -204,7 +203,7 @@ def test_conv2d_number_primfunc_args( expected_num_params = 6 if enable_bias else 5 cmsisnn_tir_mod = None for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items(): - if "cmsis-nn" == target.kind.name: + if target.kind.name == "cmsis-nn": cmsisnn_tir_mod = mod cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] @@ -230,6 +229,7 @@ def test_conv2d_symmetric_padding_int8( kernel_scale, out_channels, ): + """Tests QNN 
Conv2D where the padding is symmetric on both sides of input""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -319,6 +319,7 @@ def test_conv2d_asymmetric_padding_int8( kernel_scale, out_channels, ): + """Tests QNN Conv2D where the padding is asymmetric on different sides of input""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -390,6 +391,7 @@ def test_conv2d_asymmetric_padding_int8( ) +# pylint: disable=import-outside-toplevel @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @@ -397,6 +399,7 @@ def test_conv2d_asymmetric_padding_int8( @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding, activation): + """Compares TVM output against TFLite output""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -460,6 +463,7 @@ def test_depthwise_int8( out_channels, depth_multiplier, ): + """Tests QNN Depthwise int8 op via CMSIS-NN""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -537,6 +541,7 @@ def test_depthwise_int8( def parameterize_for_invalid_model(test): + """Generates non int8 inputs""" in_dtype = ["uint8", "int8"] kernel_dtype = ["uint8", "int8"] kernel_zero_point = [-33, 10, 0] @@ -560,12 +565,12 @@ def test_invalid_parameters( kernel_dtype, kernel_zero_point, ): + """Tests Depthwise op for non int8 inputs""" ifm_shape = (1, 28, 28, 12) out_channels = 2 input_scale = 1 input_zero_point = 24 kernel_scale = [0.11, 0.0237] - in_min, in_max = get_range_for_dtype_str(in_dtype) kernel_layout = "HWIO" kernel_shape = [3, 3, ifm_shape[3], out_channels] diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py index 789d400faf978..8831596d40e63 100644 --- a/tests/python/contrib/test_cmsisnn/test_extract_constants.py +++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py @@ -16,8 +16,6 @@ # under the License. 
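Helpers such as parameterize_for_constant_inputs and parameterize_for_invalid_model above follow a common pytest pattern: build the cross product of the argument lists, drop the combinations that the docstring rules out, then apply pytest.mark.parametrize to the wrapped test. A generic sketch of that pattern (the lists, names and filtering rule below are illustrative assumptions, not the exact helper bodies):

import itertools
import pytest

def parameterize_filtered(test):
    ops = ["add", "mul"]
    operands = ["variable", "tensor_constant", "scalar_constant"]
    combinations = [
        (op, in0, in1)
        for op, in0, in1 in itertools.product(ops, operands, operands)
        # at least one operand must be a constant, and both must not be scalars
        if not (in0 == in1 == "variable") and not (in0 == in1 == "scalar_constant")
    ]
    return pytest.mark.parametrize(["op", "input_0", "input_1"], combinations)(test)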
"""CMSIS-NN integration tests: extract_constants pass""" -import itertools -import math import numpy as np import pytest import tvm @@ -28,6 +26,8 @@ class CheckFunctionsForConstants(tvm.relay.ExprVisitor): + """Provides methods to test number of constants present in a function""" + def __init__(self): super().__init__() self.num_constants_ = 0 @@ -38,7 +38,7 @@ def visit_call(self, call): if isinstance(arg, relay.Constant) and arg.data.numpy().ndim > 0: self.num_constants_ += 1 - def check_num_constants(self, func): + def check_num_constants(self): assert self.num_constants_ == 0, "Functions should not have constant arguments in Calls" @@ -56,118 +56,132 @@ def set_composite_func_attr(func, name): @tvm.testing.requires_cmsisnn def test_external_function(): - y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") - x0 = relay.var("x0", shape=(8, 8)) - y0_const = relay.const(y0_data, "float32") - z0 = x0 + y0_const - ef = relay.Function([x0], z0, relay.TensorType((8, 8), "float32")) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - - x = relay.var("x", shape=(8, 8)) - c = relay.Call(ev, [x]) - mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) - mv = relay.GlobalVar("main") + """Tests the pass ExternConstants when the function is a global function""" + input1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + input0 = relay.var("input0", shape=(8, 8)) + input1_const = relay.const(input1_data, "float32") + binary_op = input0 + input1_const + extern_func = relay.Function([input0], binary_op, relay.TensorType((8, 8), "float32")) + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + + arg = relay.var("arg", shape=(8, 8)) + call_extern_func = relay.Call(global_var, [arg]) + main_func = relay.Function([arg], call_extern_func, relay.TensorType((8, 8), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = ExtractConstantsFromPartitionedFunction()(mod) - CheckFunctionsForConstants().check_num_constants(mod[ev]) + constant_verifier = CheckFunctionsForConstants() + constant_verifier.visit_function(mod[global_var]) + constant_verifier.check_num_constants() relay.transform.InferType()(mod) @tvm.testing.requires_cmsisnn def test_nested_function(): - y1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") - x1 = relay.var("x1", shape=(8, 8)) - y1_const = relay.const(y1_data, "float32") - z1 = x1 + y1_const - w1 = z1 * relay.const(5.0, "float32") - lf = relay.Function([x1], w1, relay.TensorType((8, 8), "float32")) - lf = set_composite_func_attr(lf, "cmsis-nn") - - x0 = relay.var("x0", shape=(8, 8)) - c0 = relay.Call(lf, [x0]) - ef = relay.Function([x0], c0, relay.TensorType((8, 8), "float32")) - - x = relay.var("x", shape=(8, 8)) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, [x]) - mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) - mv = relay.GlobalVar("main") + """Tests the pass ExternConstants when a composite function + is present within global function + """ + input1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + input0 = relay.var("input0", shape=(8, 8)) + input1_const = relay.const(input1_data, "float32") + binary_op0 = input0 + input1_const + binary_op1 = binary_op0 * relay.const(5.0, "float32") + local_func = 
relay.Function([input0], binary_op1, relay.TensorType((8, 8), "float32")) + local_func = set_composite_func_attr(local_func, "cmsis-nn") + + arg = relay.var("arg", shape=(8, 8)) + call_local_func = relay.Call(local_func, [arg]) + extern_func = relay.Function([arg], call_local_func, relay.TensorType((8, 8), "float32")) + + global_arg = relay.var("garg", shape=(8, 8)) + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, [global_arg]) + main_func = relay.Function([global_arg], call_extern_func, relay.TensorType((8, 8), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = ExtractConstantsFromPartitionedFunction()(mod) - CheckFunctionsForConstants().check_num_constants(mod[ev]) + constant_verifier = CheckFunctionsForConstants() + constant_verifier.visit_function(mod[global_var]) + constant_verifier.check_num_constants() relay.transform.InferType()(mod) @tvm.testing.requires_cmsisnn def test_multiple_functions(): - y20_data = np.random.uniform(0, 1, (8, 8)).astype("float32") - x20 = relay.var("x20", shape=(8, 8)) - y20_const = relay.const(y20_data, "float32") - z20 = x20 + y20_const - f20 = relay.Function([x20], z20, relay.TensorType((8, 8), "float32")) - f20 = set_composite_func_attr(f20, "cmsis-nn") - - y21_data = np.random.uniform(0, 1, (8, 8)).astype("float32") - x21 = relay.var("x21", shape=(8, 8)) - y21_const = relay.const(y21_data, "float32") - z21 = x21 + y21_const - f21 = relay.Function([x21], z21, relay.TensorType((8, 8), "float32")) - f21 = set_composite_func_attr(f21, "cmsis-nn") - - x10 = relay.var("x10", shape=(8, 8)) - c10 = relay.Call(f20, [x10]) - c11 = relay.Call(f21, [c10]) - ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32")) - x0 = relay.var("x0", shape=(8, 8)) - ev = relay.GlobalVar("cmsis-nn") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, [x0]) - mf = relay.Function([x0], c, relay.TensorType((8, 8), "float32")) - mv = relay.GlobalVar("main") + """Tests the pass ExternConstants when global function + contains multiple composite functions inside it + """ + f0_input1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + f0_input0 = relay.var("f0_in0", shape=(8, 8)) + f0_input1_const = relay.const(f0_input1_data, "float32") + f0_binary_op = f0_input0 + f0_input1_const + f0_func = relay.Function([f0_input0], f0_binary_op, relay.TensorType((8, 8), "float32")) + f0_func = set_composite_func_attr(f0_func, "cmsis-nn") + + f1_input1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + f1_input0 = relay.var("f1_in0", shape=(8, 8)) + f1_input1_const = relay.const(f1_input1_data, "float32") + f1_binary_op = f1_input0 + f1_input1_const + f1_func = relay.Function([f1_input0], f1_binary_op, relay.TensorType((8, 8), "float32")) + f1_func = set_composite_func_attr(f1_func, "cmsis-nn") + + arg0 = relay.var("arg0", shape=(8, 8)) + call_local_func0 = relay.Call(f0_func, [arg0]) + call_local_func1 = relay.Call(f1_func, [call_local_func0]) + extern_func = relay.Function([arg0], call_local_func1, relay.TensorType((8, 8), "float32")) + input0 = relay.var("input0", shape=(8, 8)) + global_var = relay.GlobalVar("cmsis-nn") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, [input0]) + main_func = relay.Function([input0], 
call_extern_func, relay.TensorType((8, 8), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = ExtractConstantsFromPartitionedFunction()(mod) - CheckFunctionsForConstants().check_num_constants(mod[ev]) + constant_verifier = CheckFunctionsForConstants() + constant_verifier.visit_function(mod[global_var]) + constant_verifier.check_num_constants() relay.transform.InferType()(mod) @tvm.testing.requires_cmsisnn def test_main_function(): - x0 = relay.var("x0", shape=(8, 8)) - y0 = relay.var("y0", shape=(8, 8)) - z0 = x0 + y0 - ef = relay.Function([x0, y0], z0, relay.TensorType((8, 8), "float32")) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - - x = relay.var("x", shape=(8, 8)) - y_data = np.random.uniform(0, 1, (8, 8)).astype("float32") - y_const = relay.const(y_data, "float32") - z = x + y_const - c = relay.Call(ev, [x, z]) - mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) - mv = relay.GlobalVar("main") + """Tests the pass ExternConstants on main function""" + input0 = relay.var("input0", shape=(8, 8)) + input1 = relay.var("input1", shape=(8, 8)) + binary_op = input0 + input1 + extern_func = relay.Function([input0, input1], binary_op, relay.TensorType((8, 8), "float32")) + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + + arg = relay.var("arg", shape=(8, 8)) + input_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + input_const = relay.const(input_data, "float32") + binary_op = arg + input_const + call_extern_func = relay.Call(global_var, [arg, binary_op]) + main_func = relay.Function([arg], call_extern_func, relay.TensorType((8, 8), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = ExtractConstantsFromPartitionedFunction()(mod) check_for_constants = CheckFunctionsForConstants() - check_for_constants.visit_call(mod[mv].body) + check_for_constants.visit_call(mod[main_var].body) assert ( check_for_constants.num_constants_ == 1 ), "main() should have same number of arguments as before" @@ -176,6 +190,7 @@ def test_main_function(): @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("external_compiler", ["cmsis-nn", "other_compiler"]) def test_multiple_functions_non_cmsisnn_compiler(external_compiler): + """Tests the pass ExternConstants on non CMSIS-NN targets""" y20_data = np.random.uniform(0, 1, (8, 8)).astype("float32") x20 = relay.var("x20", shape=(8, 8)) y20_const = relay.const(y20_data, "float32") @@ -183,8 +198,8 @@ def test_multiple_functions_non_cmsisnn_compiler(external_compiler): f20 = relay.Function([x20], z20, relay.TensorType((8, 8), "float32")) f20 = set_composite_func_attr(f20, "cmsis-nn.qnn_op_1") x10 = relay.var("x10", shape=(8, 8)) - c10 = relay.Call(f20, [x10]) - ef0 = relay.Function([x10], c10, relay.TensorType((8, 8), "float32")) + call_local_func0 = relay.Call(f20, [x10]) + extern_func0 = relay.Function([x10], call_local_func0, relay.TensorType((8, 8), "float32")) y21_data = np.random.uniform(0, 1, (8, 8)).astype("float32") x21 = relay.var("x21", shape=(8, 8)) @@ -193,27 +208,27 @@ def test_multiple_functions_non_cmsisnn_compiler(external_compiler): f21 = relay.Function([x21], z21, relay.TensorType((8, 8), "float32")) f21 = set_composite_func_attr(f21, "cmsis-nn.qnn_op_2") 
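All of these tests share the same shape: a Relay function tagged for the external compiler, a main() that calls it, and a check on where constants end up after the pass. A condensed sketch of that structure (variable names, the with_attr calls and the single-constant module are illustrative; the real tests go through set_external_func_attr):

import numpy as np
import tvm
from tvm import relay

shape, dtype = (8, 8), "float32"
const = relay.const(np.ones(shape, dtype))
x = relay.var("x", shape=shape)
ext = relay.Function([x], x + const, relay.TensorType(shape, dtype))
ext = ext.with_attr("Compiler", "cmsis-nn")
ext = ext.with_attr("global_symbol", "external_function")

gv = relay.GlobalVar("external_function")
arg = relay.var("arg", shape=shape)
main = relay.Function([arg], relay.Call(gv, [arg]), relay.TensorType(shape, dtype))

mod = tvm.IRModule({gv: ext, relay.GlobalVar("main"): main})
# After ExtractConstantsFromPartitionedFunction()(mod), the tensor constant is
# expected to leave the external function body and be supplied by main() instead.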
x11 = relay.var("x11", shape=(8, 8)) - c11 = relay.Call(f21, [x11]) - ef1 = relay.Function([x11], c11, relay.TensorType((8, 8), "float32")) - - x0 = relay.var("x0", shape=(8, 8)) - ev0 = relay.GlobalVar("external_function_0") - ef0 = set_external_func_attr(ef0, external_compiler, ev0.name_hint) - c0 = relay.Call(ev0, [x0]) - ev1 = relay.GlobalVar("external_function_1") - ef1 = set_external_func_attr(ef1, external_compiler, ev1.name_hint) - c1 = relay.Call(ev1, [c0]) - mf = relay.Function([x0], c1, relay.TensorType((8, 8), "float32")) - mv = relay.GlobalVar("main") + call_local_func1 = relay.Call(f21, [x11]) + extern_func1 = relay.Function([x11], call_local_func1, relay.TensorType((8, 8), "float32")) + + input0 = relay.var("input0", shape=(8, 8)) + global_var0 = relay.GlobalVar("external_function_0") + extern_func0 = set_external_func_attr(extern_func0, external_compiler, global_var0.name_hint) + call_extern_func0 = relay.Call(global_var0, [input0]) + global_var1 = relay.GlobalVar("external_function_1") + extern_func1 = set_external_func_attr(extern_func1, external_compiler, global_var1.name_hint) + call_extern_func1 = relay.Call(global_var1, [call_extern_func0]) + main_func = relay.Function([input0], call_extern_func1, relay.TensorType((8, 8), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev0] = ef0 - mod[ev1] = ef1 - mod[mv] = mf + mod[global_var0] = extern_func0 + mod[global_var1] = extern_func1 + mod[main_var] = main_func mod = ExtractConstantsFromPartitionedFunction()(mod) check_for_constants = CheckFunctionsForConstants() - check_for_constants.visit_call(mod[mv].body) + check_for_constants.visit_call(mod[main_var].body) num_extracted_constants = 0 if external_compiler == "cmsis-nn": diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py index c5d97f807b046..3a2061096dc12 100644 --- a/tests/python/contrib/test_cmsisnn/test_fully_connected.py +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -27,11 +27,9 @@ from tvm.micro.testing.aot_test_utils import ( AOT_USMP_CORSTONE300_RUNNER, ) -from utils import ( - skip_if_no_reference_system, +from .utils import ( make_module, get_range_for_dtype_str, - get_same_padding, get_conv2d_qnn_params, make_qnn_relu, assert_partitioned_function, @@ -55,9 +53,9 @@ def make_model( relu_type="NONE", ): """Return a model and any parameters it may have""" - a = relay.var("input", shape=in_shape, dtype=dtype) + input_ = relay.var("input", shape=in_shape, dtype=dtype) rng = np.random.default_rng(12321) - w = tvm.nd.array( + weight = tvm.nd.array( rng.integers( np.iinfo(kernel_dtype).min, high=np.iinfo(kernel_dtype).max, @@ -65,9 +63,9 @@ def make_model( dtype=kernel_dtype, ) ) - weight_const = relay.const(w, kernel_dtype) - fc = relay.qnn.op.dense( - a, + weight_const = relay.const(weight, kernel_dtype) + dense = relay.qnn.op.dense( + input_, weight_const, input_zero_point=relay.const(input_zero_point, "int32"), kernel_zero_point=relay.const(kernel_zero_point, "int32"), @@ -77,9 +75,9 @@ def make_model( out_dtype="int32", ) - b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) - bias_const = relay.const(b, "int32") - last_op = relay.nn.bias_add(fc, bias_const) if enable_bias else fc + bias = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(bias, "int32") + last_op = relay.nn.bias_add(dense, bias_const) if enable_bias else dense requant_input_sc = input_scale 
* kernel_scale last_op = relay.qnn.op.requantize( last_op, @@ -90,7 +88,7 @@ def make_model( out_dtype=dtype, ) last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) - params = {"w": w, "b": b} + params = {"w": weight, "b": bias} return last_op, params @@ -98,7 +96,6 @@ def make_model( @pytest.mark.parametrize("in_shape", [(2, 28), (1, 64)]) @pytest.mark.parametrize("out_channels", [12, 128]) @pytest.mark.parametrize("enable_bias", [False, True]) -@pytest.mark.parametrize("relu_type", ["RELU"]) @pytest.mark.parametrize( "input_zero_point, input_scale, kernel_scale", [(10, 0.0128, 0.11), (-64, 0.0256, 1.37)], @@ -110,8 +107,8 @@ def test_op_int8( input_scale, kernel_scale, out_channels, - relu_type, ): + """Test QNN fully connected layer""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -170,6 +167,7 @@ def test_op_int8( def parameterize_for_invalid_model(test): + """Generates parameters for non int8 inputs to fully connected layer""" in_dtype = ["uint8", "int8"] kernel_dtype = ["uint8", "int8"] kernel_zero_point = [-33, 10, 0] @@ -193,12 +191,12 @@ def test_invalid_parameters( kernel_dtype, kernel_zero_point, ): + """Tests fully connected layer with non int8 inputs""" in_shape = (2, 28) out_channels = 2 input_scale = 1 input_zero_point = 24 kernel_scale = [0.11, 0.0237] - in_min, in_max = get_range_for_dtype_str(in_dtype) kernel_shape = [out_channels, in_shape[1]] conv2d_kernel_shape = [1, 1, kernel_shape[0], kernel_shape[1]] diff --git a/tests/python/contrib/test_cmsisnn/test_generate_constants.py b/tests/python/contrib/test_cmsisnn/test_generate_constants.py index cded0f03566d4..e6faa1a243f5c 100644 --- a/tests/python/contrib/test_cmsisnn/test_generate_constants.py +++ b/tests/python/contrib/test_cmsisnn/test_generate_constants.py @@ -16,7 +16,6 @@ # under the License. 
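The requantization parameters in make_model above follow the usual quantization algebra: an int8 dense or conv accumulates (q - zero_point) products in int32, so the accumulator carries the scale input_scale * kernel_scale, which is then rescaled to the output quantization. A small numeric sketch (all values are illustrative):

import numpy as np

input_scale, kernel_scale = 0.0128, 0.11      # illustrative scales
output_scale, output_zero_point = 0.05, -10

acc_scale = input_scale * kernel_scale        # scale of the int32 accumulator
acc = np.int32(1234)                          # some accumulator value
real_value = acc_scale * float(acc)           # real number the accumulator encodes
q_out = int(np.round(real_value / output_scale)) + output_zero_point  # requantized output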
"""CMSIS-NN integration tests: generate_constants pass""" -import itertools import math import numpy as np import pytest @@ -25,9 +24,8 @@ from tvm import relay from tvm.relay.op.contrib import cmsisnn -from utils import ( +from .utils import ( make_module, - get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, @@ -43,6 +41,8 @@ def quantize_scale(scale): class CheckGeneratedConstants(tvm.relay.ExprVisitor): + """Provides methods to compare against expected quantization parameters""" + def __init__(self, enable_bias, multiplier, shift): super().__init__() self.num_constant_args_ = 0 @@ -53,7 +53,6 @@ def __init__(self, enable_bias, multiplier, shift): def visit_call(self, call): super().visit_call(call) if isinstance(call.op, tvm.ir.expr.GlobalVar): - # extern_fn_call(input, weight, multiplier, weight_scale, bias_optional, input_scale, shift) multiplier = call.args[2] shift = call.args[6] if self.enable_bias_ else call.args[5] assert isinstance( @@ -107,7 +106,7 @@ def make_model( weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) rng = np.random.default_rng(12321) - w = tvm.nd.array( + weight = tvm.nd.array( rng.integers( np.iinfo(kernel_dtype).min, high=np.iinfo(kernel_dtype).max, @@ -115,7 +114,7 @@ def make_model( dtype=kernel_dtype, ) ) - weight_const = relay.const(w, kernel_dtype) + weight_const = relay.const(weight, kernel_dtype) conv = relay.qnn.op.conv2d( a, weight_const, @@ -133,8 +132,8 @@ def make_model( padding=p, out_dtype="int32", ) - b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) - bias_const = relay.const(b, "int32") + bias = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(bias, "int32") last_op = relay.nn.bias_add(conv, bias_const, axis=3) if enable_bias else conv requant_input_sc = [sc * input_scale for sc in kernel_scale] last_op = relay.qnn.op.requantize( @@ -146,7 +145,7 @@ def make_model( out_dtype=dtype, ) last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) - params = {"w": w, "b": b} + params = {"w": weight, "b": bias} return last_op, params @@ -163,6 +162,7 @@ def test_op_int8( kernel_scale, out_channels, ): + """Tests for CMSIS-NN constants when the dtype is int8""" ifm_shape = (1, 28, 28, 3) padding = "VALID" strides = (1, 1) @@ -175,7 +175,6 @@ def test_op_int8( kernel_w = kernel_size[1] dtype = "int8" relu_type = "RELU" - in_min, in_max = get_range_for_dtype_str(dtype) weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) diff --git a/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py b/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py index d0a8547d32acd..c66f9d0e07260 100644 --- a/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py +++ b/tests/python/contrib/test_cmsisnn/test_invalid_graphs.py @@ -16,17 +16,14 @@ # under the License. 
"""CMSIS-NN integration tests: Tests invalid graphs""" -import itertools import numpy as np -import pytest import tvm -from tvm import relay from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import ( AOT_USMP_CORSTONE300_RUNNER, ) -from utils import ( +from .utils import ( skip_if_no_reference_system, get_range_for_dtype_str, ) @@ -35,13 +32,14 @@ @skip_if_no_reference_system @tvm.testing.requires_cmsisnn def test_empty_function(): - ORIGINAL_MODEL = """ + """Test partitioned function without composite function""" + original_model = """ #[version = "0.0.5"] def @main(%data : Tensor[(16, 29), int8]) -> Tensor[(16, 29), int8] { add(%data, %data) } """ - CMSISNN_MODEL = """ + cmsisnn_model = """ #[version = "0.0.5"] def @tvmgen_default_cmsis_nn_main_1(%i1: Tensor[(16, 29), int8], Inline=1, Compiler="cmsis-nn", global_symbol="tvmgen_default_cmsis_nn_main_1", Primitive=1) -> Tensor[(16, 29), int8] { add(%i1, %i1) @@ -51,8 +49,8 @@ def @main(%data : Tensor[(16, 29), int8]) -> Tensor[(16, 29), int8] { %1 } """ - orig_mod = tvm.parser.fromtext(ORIGINAL_MODEL) - cmsisnn_mod = tvm.parser.fromtext(CMSISNN_MODEL) + orig_mod = tvm.parser.fromtext(original_model) + cmsisnn_mod = tvm.parser.fromtext(cmsisnn_model) params = {} # validate the output diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index 3b1e2331f2ff5..6f9f3743a6226 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -17,8 +17,6 @@ """CMSIS-NN: testing with networks""" -import sys - import pytest import numpy as np @@ -26,20 +24,21 @@ from tvm import relay from tvm.contrib.download import download_testdata from tvm.relay.op.contrib import cmsisnn - -from utils import skip_if_no_reference_system, get_range_for_dtype_str from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import ( AOT_CORSTONE300_RUNNER, AOT_USMP_CORSTONE300_RUNNER, ) +from .utils import skip_if_no_reference_system, get_range_for_dtype_str - +# pylint: disable=import-outside-toplevel def _convert_to_relay( tflite_model_buf, input_data, input_node, ): + """Converts TFLite model to Relay module and params""" + def convert_to_list(x): if not isinstance(x, list): x = [x] @@ -62,9 +61,9 @@ def convert_to_list(x): shape_dict = {} dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype.name + for i, name in enumerate(input_node): + shape_dict[name] = input_data[i].shape + dtype_dict[name] = input_data[i].dtype.name mod, params = relay.frontend.from_tflite( tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict @@ -78,8 +77,13 @@ def convert_to_list(x): @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("test_runner", [AOT_CORSTONE300_RUNNER, AOT_USMP_CORSTONE300_RUNNER]) def test_cnn_small(test_runner): + """Download a small network and tests TVM via CMSIS-NN output against TFLite output""" # download the model - base_url = "https://github.com/ARM-software/ML-zoo/raw/48a22ee22325d15d2371a6df24eb7d67e21dcc97/models/keyword_spotting/cnn_small/tflite_int8" + base_url = ( + "https://github.com/ARM-software/ML-zoo/raw/" + "48a22ee22325d15d2371a6df24eb7d67e21dcc97" + "/models/keyword_spotting/cnn_small/tflite_int8" + ) file_to_download = "cnn_s_quantized.tflite" file_saved = "cnn_s_quantized_15Dec2021.tflite" model_file = 
download_testdata("{}/{}".format(base_url, file_to_download), file_saved) diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index 1fd280b7d81a1..6b719cdc9938e 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -16,7 +16,6 @@ # under the License. """CMSIS-NN integration tests: Conv2D""" -import itertools import numpy as np import pytest import tvm @@ -25,12 +24,10 @@ from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER -from utils import ( - skip_if_no_reference_system, +from .utils import ( make_module, get_range_for_dtype_str, get_same_padding, - get_conv2d_qnn_params, make_qnn_relu, assert_partitioned_function, assert_no_external_function, @@ -49,7 +46,9 @@ def make_model( relu_type="RELU", layout="NHWC", ): - """Return a model and any parameters it may have, all parameters are defaulted to known good values""" + """Return a model and any parameters it may have, + all parameters are defaulted to known good values + """ op = relay.var("input", shape=shape, dtype=dtype) pad_ = (0, 0, 0, 0) if padding == "SAME": @@ -61,12 +60,12 @@ def make_model( pad_value=zero_point, pad_mode="constant", ) - if pool_op == relay.nn.avg_pool2d: + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: op = relay.cast(op, "int32") op = pool_op( op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout ) - if pool_op == relay.nn.avg_pool2d: + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: op = relay.cast(op, dtype) op = make_qnn_relu(op, relu_type, scale, zero_point, dtype) return op @@ -91,6 +90,7 @@ def test_op_int8( zero_point, scale, ): + """Tests QNN pooling op for int8 inputs""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -138,6 +138,7 @@ def test_op_int8( @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) def test_invalid_datatype(op): + """Checks CMSIS-NN partitioning for non int8 dtype""" model = make_model(pool_op=op, dtype="int64") orig_mod = make_module(model) @@ -148,6 +149,7 @@ def test_invalid_datatype(op): @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) def test_invalid_batch_size(op): + """Checks CMSIS-NN partitioning when batch size is not 1""" model = make_model( pool_op=op, shape=(2, 28, 28, 12), @@ -161,6 +163,7 @@ def test_invalid_batch_size(op): @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) def test_invalid_layout(op): + """Checks CMSIS-NN partitioning when layout is not NHWC""" model = make_model(pool_op=op, layout="NCHW") orig_mod = make_module(model) diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py index 35bdabf3171c4..557a65aeffcaf 100644 --- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py +++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py @@ -16,10 +16,7 @@ # under the License. 
"""CMSIS-NN integration tests: scalar_to_tensor_constant pass""" -import sys - import numpy as np -import pytest import tvm import tvm.testing from tvm import relay @@ -56,6 +53,8 @@ def make_binary_op( class CheckFunctionsForConstants(tvm.relay.ExprVisitor): + """Provides method to test number of scalar constants present in a function""" + def __init__(self): super().__init__() self.num_constants_ = 0 @@ -66,7 +65,7 @@ def visit_call(self, call): if isinstance(arg, relay.Constant) and arg.data.numpy().ndim > 0: self.num_constants_ += 1 - def check_num_constants(self, func): + def check_num_constants(self): assert self.num_constants_ == 0, "Functions should not have constant arguments in Calls" @@ -84,44 +83,45 @@ def set_composite_func_attr(func, name): @tvm.testing.requires_cmsisnn def test_single_scalar_position_0(): + """Tests conversion to tensor constant when first operand is a scalar""" dtype = "int8" shape = (8, 8) - x0 = generate_variable("x0", None, dtype) - x1 = generate_variable("x1", shape, dtype) - z1 = make_binary_op( + operand0 = generate_variable("operand0", None, dtype) + operand1 = generate_variable("operand1", shape, dtype) + binary_op = make_binary_op( relay.qnn.op.add, - x0, - x1, + operand0, + operand1, input_0_scale=0.0128, input_0_zero_point=32, input_1_scale=0.256, input_1_zero_point=-64, ) - lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype)) - lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add") + local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype)) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add") - y0 = relay.expr.const(3, dtype) - y1 = relay.var("y1", shape=shape, dtype=dtype) - c0 = relay.Call(lf, [y0, y1]) - ef = relay.Function([y1], c0, relay.TensorType(shape, dtype)) + arg0 = relay.expr.const(3, dtype) + arg1 = relay.var("arg1", shape=shape, dtype=dtype) + call_local_func = relay.Call(local_func, [arg0, arg1]) + extern_func = relay.Function([arg1], call_local_func, relay.TensorType(shape, dtype)) x = relay.var("x", shape=shape, dtype=dtype) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, [x]) - mf = relay.Function([x], c, relay.TensorType(shape, dtype)) - mv = relay.GlobalVar("main") + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, [x]) + main_func = relay.Function([x], call_extern_func, relay.TensorType(shape, dtype)) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) mod = relay.transform.InferType()(mod) check_for_constants = CheckFunctionsForConstants() - check_for_constants.visit_call(mod[ev].body) + check_for_constants.visit_call(mod[global_var].body) assert ( check_for_constants.num_constants_ == 1 ), "Scalar constant wasn't converted into tensor constant" @@ -129,44 +129,45 @@ def test_single_scalar_position_0(): @tvm.testing.requires_cmsisnn def test_single_scalar_position_1(): + """Tests conversion to tensor constant when second operand is a scalar""" dtype = "int8" shape = (8, 8) - x0 = generate_variable("x0", shape, dtype) - x1 = generate_variable("x1", None, dtype) - z1 = make_binary_op( + operand0 = generate_variable("operand0", shape, dtype) + operand1 = generate_variable("operand1", None, 
dtype) + binary_op = make_binary_op( relay.qnn.op.add, - x0, - x1, + operand0, + operand1, input_0_scale=0.0128, input_0_zero_point=32, input_1_scale=0.256, input_1_zero_point=-64, ) - lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype)) - lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add") + local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype)) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add") - y0 = relay.var("y0", shape=shape, dtype=dtype) - y1 = relay.expr.const(3, dtype) - c0 = relay.Call(lf, [y0, y1]) - ef = relay.Function([y0], c0, relay.TensorType(shape, dtype)) + arg0 = relay.var("arg0", shape=shape, dtype=dtype) + arg1 = relay.expr.const(3, dtype) + call_local_func = relay.Call(local_func, [arg0, arg1]) + extern_func = relay.Function([arg0], call_local_func, relay.TensorType(shape, dtype)) x = relay.var("x", shape=shape, dtype=dtype) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, [x]) - mf = relay.Function([x], c, relay.TensorType(shape, dtype)) - mv = relay.GlobalVar("main") + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, [x]) + main_func = relay.Function([x], call_extern_func, relay.TensorType(shape, dtype)) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) mod = relay.transform.InferType()(mod) check_for_constants = CheckFunctionsForConstants() - check_for_constants.visit_call(mod[ev].body) + check_for_constants.visit_call(mod[global_var].body) assert ( check_for_constants.num_constants_ == 1 ), "Scalar constant wasn't converted into tensor constant" @@ -174,83 +175,85 @@ def test_single_scalar_position_1(): @tvm.testing.requires_cmsisnn def test_primary_operands_all_scalars(): + """Tests conversion to tensor constants all operands are scalars""" dtype = "int8" shape = None - x0 = generate_variable("x0", None, dtype) - x1 = generate_variable("x1", None, dtype) - z1 = make_binary_op( + operand0 = generate_variable("operand0", None, dtype) + operand1 = generate_variable("operand1", None, dtype) + binary_op = make_binary_op( relay.qnn.op.add, - x0, - x1, + operand0, + operand1, input_0_scale=0.0128, input_0_zero_point=32, input_1_scale=0.256, input_1_zero_point=-64, ) - lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype)) - lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add") + local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype)) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add") - y0 = relay.expr.const(7, dtype) - y1 = relay.expr.const(3, dtype) - c0 = relay.Call(lf, [y0, y1]) - ef = relay.Function([], c0, relay.TensorType(shape, dtype)) + arg0 = relay.expr.const(7, dtype) + arg1 = relay.expr.const(3, dtype) + call_local_func = relay.Call(local_func, [arg0, arg1]) + extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype)) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, []) - mf = relay.Function([], c, relay.TensorType(shape, dtype)) - mv = relay.GlobalVar("main") + global_var = relay.GlobalVar("external_function") + extern_func = 
set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, []) + main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype)) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) new_mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod[ev].body, new_mod[ev].body) + assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body) @tvm.testing.requires_cmsisnn def test_all_primary_operands_tensor_constants(): + """Tests conversion to tensor constants all operands are tensors""" dtype = "int8" shape = (1, 3, 3, 32) - x0 = generate_variable("x0", shape, dtype) - x1 = generate_variable("x1", shape, dtype) - z1 = make_binary_op( + operand0 = generate_variable("operand0", shape, dtype) + operand1 = generate_variable("operand1", shape, dtype) + binary_op = make_binary_op( relay.qnn.op.add, - x0, - x1, + operand0, + operand1, input_0_scale=0.0128, input_0_zero_point=32, input_1_scale=0.256, input_1_zero_point=-64, ) - lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype)) - lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add") + local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype)) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add") rng = np.random.default_rng(12345) - y0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype)) - y1 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype)) - c0 = relay.Call(lf, [y0, y1]) - ef = relay.Function([], c0, relay.TensorType(shape, dtype)) + arg0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype)) + arg1 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype)) + call_local_func = relay.Call(local_func, [arg0, arg1]) + extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype)) - ev = relay.GlobalVar("external_function") - ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint) - c = relay.Call(ev, []) - mf = relay.Function([], c, relay.TensorType(shape, dtype)) - mv = relay.GlobalVar("main") + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint) + call_extern_func = relay.Call(global_var, []) + main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype)) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) new_mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod[ev].body, new_mod[ev].body) + assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body) @tvm.testing.requires_cmsisnn @@ -258,26 +261,28 @@ def test_non_cmsisnn_ext_func(): """Non CMSISNN functions should not be altered.""" def get_mod(): - x1 = relay.var("x1", shape=None) - x2 = relay.var("x2", shape=None) - z1 = x1 + x2 - lf = relay.Function([x1, x2], z1, relay.TensorType((), "float32")) - lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add") - - y0 = relay.expr.const(5, "float32") - y1 = relay.expr.const(3, "float32") - c0 = relay.Call(lf, [y0, y1]) - ef = relay.Function([], c0, relay.TensorType((), "float32")) - - ev = relay.GlobalVar("external_function") - ef = 
set_external_func_attr(ef, "foo", ev.name_hint) - c = relay.Call(ev, []) - mf = relay.Function([], c, relay.TensorType((), "float32")) - mv = relay.GlobalVar("main") + operand1 = relay.var("operand1", shape=None) + operand2 = relay.var("operand2", shape=None) + binary_op = operand1 + operand2 + local_func = relay.Function( + [operand1, operand2], binary_op, relay.TensorType((), "float32") + ) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add") + + arg0 = relay.expr.const(5, "float32") + arg1 = relay.expr.const(3, "float32") + call_local_func = relay.Call(local_func, [arg0, arg1]) + extern_func = relay.Function([], call_local_func, relay.TensorType((), "float32")) + + global_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "foo", global_var.name_hint) + call_extern_func = relay.Call(global_var, []) + main_func = relay.Function([], call_extern_func, relay.TensorType((), "float32")) + main_var = relay.GlobalVar("main") mod = tvm.IRModule() - mod[ev] = ef - mod[mv] = mf + mod[global_var] = extern_func + mod[main_var] = main_func mod = relay.transform.InferType()(mod) return mod diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index 840d0e6f4436d..c6d2e4ec45371 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -16,8 +16,6 @@ # under the License. """CMSIS-NN integration tests: Softmax""" - -import sys import itertools import numpy as np @@ -26,16 +24,16 @@ import tvm.testing from tvm import relay from tvm.relay.op.contrib import cmsisnn +from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER -from utils import ( +from .utils import ( skip_if_no_reference_system, make_module, get_range_for_dtype_str, assert_partitioned_function, assert_no_external_function, ) -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data -from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER def make_model( @@ -62,6 +60,7 @@ def make_model( @pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]]) @tvm.testing.requires_cmsisnn def test_op_int8(zero_point, scale): + """Tests int8 QNN Softmax for CMSIS-NN""" interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER @@ -92,6 +91,7 @@ def test_op_int8(zero_point, scale): def parameterize_for_invalid_model(test): + """Generates parameters for non int8 input and output of Softmax""" in_dtype = ["uint8", "int8"] out_dtype = ["uint8", "int8"] zero_point = [-128, 64] @@ -119,6 +119,7 @@ def parameterize_for_invalid_model(test): @parameterize_for_invalid_model @tvm.testing.requires_cmsisnn def test_invalid_parameters(in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale): + """Tests for non int8 input and output of Softmax""" model = make_model( [1, 16, 16, 3], in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale ) diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 83c67cd95b1c7..e69329ebc5a42 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -17,11 +17,9 @@ """CMSIS-NN functions for testing networks""" -import platform import math +from typing import List, Union, Tuple import numpy as np -import pytest -from typing import List, Dict, Optional, Any, Union, 
Tuple import tvm from tvm import relay @@ -52,6 +50,7 @@ def visit_call(self, call): def assert_partitioned_function(orig_mod, cmsisnn_mod): + """If kCompiler attribute is missing, this function raises assertion""" attrs = [ cmsisnn_mod[var.name_hint].attrs for var in cmsisnn_mod.get_global_vars() @@ -225,3 +224,5 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): ) if fused_activation_fn == "RELU": return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) + + raise ValueError("Invalid argument provided with fused_activation_fn") From af0128158c45683d03d3cd0a8aea5afd620794c7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 9 Jun 2022 15:34:32 -0500 Subject: [PATCH 083/181] [TIR][Schedule] Allow named block and buffer arguments in Schedule (#11624) * [Schedule] Allowed string argument as block arg This has previously been implemented for `Schedule.transform_layout` in https://github.com/apache/tvm/pull/11296, extending to allow for block arguments in all `Schedule` methods. This change was only made for arguments that must be a `BlockRV`. For arguments that may be either a `BlockRV` or another type (e.g. `Schedule.get_child_blocks` accepts either `BlockRV` or `LoopRV`), this sugar is not implemented, to avoid ambiguity. * [Schedule] Allowed string argument to Schedule.reindex Similar to https://github.com/apache/tvm/pull/11269, which added this functionality to `Schedule.transform_layout`. * CI test update --- python/tvm/tir/schedule/schedule.py | 112 ++++++++++++------ .../schedule/primitive/cache_read_write.cc | 9 +- .../test_tir_schedule_cache_read_write.py | 94 ++++++++------- .../unittest/test_tir_schedule_compute_at.py | 78 ++++++------ .../test_tir_schedule_compute_inline.py | 106 +++++++++-------- .../unittest/test_tir_schedule_reduction.py | 10 +- .../unittest/test_tir_schedule_reindex.py | 32 +++-- .../unittest/test_tir_schedule_sampling.py | 10 +- .../unittest/test_tir_schedule_set_scope.py | 9 +- .../test_tir_schedule_storage_align.py | 6 +- .../test_tir_schedule_transform_layout.py | 32 +++-- .../unittest/test_tir_schedule_utilities.py | 20 ++-- 12 files changed, 291 insertions(+), 227 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d225280b655f7..d29495c430076 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -373,14 +373,14 @@ def sample_perfect_tile( @type_checked def sample_compute_location( self, - block: BlockRV, + block: Union[BlockRV, str], decision: Optional[int] = None, ) -> LoopRV: """Sample a compute-at location of the given block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block whose compute-at location is to be sampled decision : Optional[int] The sampling decision @@ -390,6 +390,8 @@ def sample_compute_location( result : LoopRV The sampled loop where the input block is to be computed at """ + block = self._normalize_block_arg(block) + return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member self, block, @@ -425,12 +427,12 @@ def get_block( ) @type_checked - def get_loops(self, block: BlockRV) -> List[LoopRV]: + def get_loops(self, block: Union[BlockRV, str]) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The query block Returns @@ -438,6 +440,7 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: loops : List[LoopRV] A list of loops above 
the given block in its scope, from outer to inner """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore # pylint: disable=no-member @type_checked @@ -457,12 +460,12 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore # pylint: disable=no-member @type_checked - def get_producers(self, block: BlockRV) -> List[BlockRV]: + def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]: """Get the producers of a specific block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block in the query Returns @@ -470,15 +473,16 @@ def get_producers(self, block: BlockRV) -> List[BlockRV]: producers : List[BlockRV] A list of producers of the given block """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore # pylint: disable=no-member @type_checked - def get_consumers(self, block: BlockRV) -> List[BlockRV]: + def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: """Get the consumers of a specific block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block in the query Returns @@ -486,6 +490,7 @@ def get_consumers(self, block: BlockRV) -> List[BlockRV]: consumers : List[BlockRV] A list of consumers of the given block """ + block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member ########## Schedule: Transform loops ########## @@ -970,7 +975,9 @@ def after_unroll(a: T.handle, b: T.handle) -> None: ########## Schedule: Insert cache stages ########## @type_checked - def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV: + def cache_read( + self, block: Union[BlockRV, str], read_buffer_index: int, storage_scope: str + ) -> BlockRV: """Create a block that reads a buffer region into a read cache. It requires: 1) There is at most one block who write the buffer in the scope. @@ -979,7 +986,7 @@ def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The consumer block of the target buffer. read_buffer_index: int @@ -1036,12 +1043,15 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B[vi, vj] = A_local[vi, vj] * 2.0 """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member self, block, read_buffer_index, storage_scope ) @type_checked - def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV: + def cache_write( + self, block: Union[BlockRV, str], write_buffer_index: int, storage_scope: str + ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: 1) There is only one block who write the buffer in the scope. @@ -1050,7 +1060,7 @@ def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: st Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the target buffer. 
write_buffer_index: int @@ -1108,12 +1118,17 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B[vi, vj] = B_local[vi, vj] """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member self, block, write_buffer_index, storage_scope ) @type_checked - def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV: + def reindex( + self, + block: Union[BlockRV, str], + buffer: Union[Tuple[str, int], str, Buffer], + ) -> BlockRV: """Create a block that read/write a buffer region into a read/write cache with reindexing. The layout of the cache will be the same as by the iterators of the block that reads/writes the buffer. It requires: @@ -1122,12 +1137,27 @@ def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> Parameters ---------- - block: BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_index_type : str - Type of the buffer index, "read" or "write" + block : Union[BlockRV, str] + + The block that accesses the target buffer. If a string, + this must uniquely identify a block. + + buffer: Union[Tuple[str,int], Buffer, str] + + The buffer to be transformed, or a specification of how to + identify the buffer to be transformed. + + If `buffer` is a tuple of ``(str,int)``, the first item + should be either "read" or "write", and the second item is + an index into the block's read or write regions. + + If `buffer` is a string, it is the name of the buffer, + which must exist within the reads/writes of the block. In + addition, the reads/writes of the block may not contain + more than one buffer with this name. + + If `buffer` is a Buffer object, it must exist within the + reads/writes of the block. 
Returns ------- @@ -1157,7 +1187,7 @@ def before_reindex( sch = tir.Schedule(before_reindex) block = sch.get_block("B") - sch.reindex(block, 0, "read) + sch.reindex(block, ("read", 0)) After applying reindex, the IR becomes: @@ -1179,6 +1209,8 @@ def after_reindex( B[vi, vj] = A_reindex[vi, vj] * 2.0 """ + block = self._normalize_block_arg(block) + buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block, buffer) assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member @@ -1190,7 +1222,7 @@ def after_reindex( @type_checked def compute_at( self, - block: BlockRV, + block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, ) -> None: @@ -1213,7 +1245,7 @@ def compute_at( Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be moved loop: LoopRV @@ -1273,6 +1305,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member self, block, @@ -1283,7 +1316,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: @type_checked def reverse_compute_at( self, - block: BlockRV, + block: Union[BlockRV, str], loop: LoopRV, preserve_unit_loops: bool = False, ) -> None: @@ -1303,7 +1336,7 @@ def reverse_compute_at( Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be moved loop: LoopRV @@ -1363,6 +1396,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member self, block, @@ -1371,7 +1405,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: ) @type_checked - def compute_inline(self, block: BlockRV) -> None: + def compute_inline(self, block: Union[BlockRV, str]) -> None: """Inline a block into its consumer(s). It requires: 1) The block is a complete non-root block, which only produces one buffer @@ -1386,7 +1420,7 @@ def compute_inline(self, block: BlockRV) -> None: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be inlined to its consumer(s) Examples @@ -1432,10 +1466,11 @@ def after_inline(a: T.handle, c: T.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @type_checked - def reverse_compute_inline(self, block: BlockRV) -> None: + def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None: """Inline a block into its only producer. 
It requires: 1) The block is a complete non-root block, which only produces and consumes one buffer @@ -1453,7 +1488,7 @@ def reverse_compute_inline(self, block: BlockRV) -> None: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be inlined to its producer Examples @@ -1499,12 +1534,13 @@ def after_inline(a: T.handle, c: T.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member ########## Schedule: Reduction ########## @type_checked - def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: + def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV: """Decompose a reduction block into two separate blocks. a) The init block, which is translated from the init statement of the reduction block; @@ -1523,7 +1559,7 @@ def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The reduction block to be decomposed loop : LoopRV The loop above which the init block is inserted before. @@ -1578,6 +1614,7 @@ def after_decompose(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] """ + block = self._normalize_block_arg(block) return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member @type_checked @@ -1734,7 +1771,7 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: @type_checked def storage_align( # pylint: disable=too-many-arguments self, - block: BlockRV, + block: Union[BlockRV, str], buffer_index: int, axis: int, factor: int, @@ -1747,7 +1784,7 @@ def storage_align( # pylint: disable=too-many-arguments Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the buffer. buffer_index : int The index of the buffer in block's write region. @@ -1812,18 +1849,19 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: ---- Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ + block = self._normalize_block_arg(block) _ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member self, block, buffer_index, axis, factor, offset ) @type_checked - def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None: + def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None: """Set the storage scope of a buffer, where the buffer is specified by the a block and a write-index Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The producer block of the buffer buffer_index : int The index of the buffer in block's write region @@ -1883,6 +1921,7 @@ def after_set_scope( ---- Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`. 
""" + block = self._normalize_block_arg(block) _ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member self, block, buffer_index, storage_scope ) @@ -2418,14 +2457,14 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> @type_checked def transform_block_layout( self, - block: BlockRV, + block: Union[BlockRV, str], index_map: Union[IndexMap, Callable], ) -> None: """Apply a transformation represented by IndexMap to block Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The block to be transformed index_map : Union[IndexMap, Callable] @@ -2470,6 +2509,7 @@ def after_transform_block_layout( vi, = T.axis.remap("S", [i]) B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0 """ + block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func(index_map) _ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index c96f88e1f6333..5a8d452f14b85 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1241,11 +1241,10 @@ struct ReIndexTraits : public UnpackedInstTraits { Integer buffer_index_type) { PythonAPICall py("reindex"); py.Input("block", block); - py.Input("buffer_index", buffer_index); - py.Input("buffer_index_type", '"' + - std::string(BufferIndexType2Str( - static_cast(buffer_index_type->value))) + - '"'); + std::ostringstream os; + os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\", " << buffer_index << ")"; + py.Input("buffer", os.str()); py.SingleOutput(outputs); return py.Str(); } diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index ef306b2c49290..5cd39c7ddaeb6 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None: ########## Testcases for cache_read ########## +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_cache_read_elementwise(): + +def test_cache_read_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") - cached_a = sch.cache_read(block_b, 0, "global") - cached_b = sch.cache_read(block_c, 0, "local") + cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global") + cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local") assert sch.get(cached_a) == sch.get(sch.get_block("A_global")) assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) assert sch.get(block_b) == sch.get(sch.get_block("B")) @@ -756,74 +758,74 @@ def test_cache_read_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_read_under_scope(): +def test_cache_read_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_b = "B" if use_block_name else sch.get_block("B") + block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_b, 0, "local") sch.cache_read(block_c, 0, "global") tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=access_under_scope) -def test_cache_read_opaque_access(): +def 
test_cache_read_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block = sch.get_block("load_store") + block = "load_store" if use_block_name else sch.get_block("load_store") sch.cache_read(block, 0, "global") tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=opaque_access) -def test_cache_read_location(): +def test_cache_read_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) -def test_continuous_cache_read(): +def test_continuous_cache_read(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.cache_read(block_c, 0, "shared") sch.cache_read(block_c, 0, "local") tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_read_with_block_predicate(): +def test_cache_read_with_block_predicate(use_block_name): sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_read(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) -def test_cache_read_non_int32_shape(): +def test_cache_read_non_int32_shape(use_block_name): sch = tir.Schedule(elementwise_shape_int64, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_read(block_b, 0, "global") tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64) -def test_cache_read_fail_multi_producer(): +def test_cache_read_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "global") -def test_cache_read_fail_index_out_of_bound(): +def test_cache_read_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 1, "global") -def test_cache_read_fail_invalid_storage_scope(): +def test_cache_read_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_read(block_b, 0, "test_scope") @@ -831,12 +833,12 @@ def test_cache_read_fail_invalid_storage_scope(): ########## Testcases for cache_write ########## -def test_cache_write_elementwise(): +def test_cache_write_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") - cached_b = sch.cache_write(block_b, 0, "local") - cached_c = sch.cache_write(block_c, 0, "global") + cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local") + cached_c = 
sch.cache_write("C" if use_block_name else block_c, 0, "global") assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) assert sch.get(block_b) == sch.get(sch.get_block("B")) @@ -845,10 +847,10 @@ def test_cache_write_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_write_under_scope(): +def test_cache_write_under_scope(use_block_name): sch = tir.Schedule(access_under_scope, debug_mask="all") - block_a = sch.get_block("A") - block_b = sch.get_block("B") + block_a = "A" if use_block_name else sch.get_block("A") + block_b = "B" if use_block_name else sch.get_block("B") block_scope = sch.get_block("scope") sch.cache_write(block_a, 0, "local") sch.cache_write(block_b, 0, "global") @@ -857,11 +859,11 @@ def test_cache_write_under_scope(): verify_trace_roundtrip(sch=sch, mod=access_under_scope) -def test_cache_write_opaque_access(): +def test_cache_write_opaque_access(use_block_name): sch = tir.Schedule(opaque_access, debug_mask="all") - block_store = sch.get_block("load_store") - block_opaque = sch.get_block("opaque") - block_match_buffer = sch.get_block("match_buffer") + block_store = "load_store" if use_block_name else sch.get_block("load_store") + block_opaque = "opaque" if use_block_name else sch.get_block("opaque") + block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer") sch.cache_write(block_store, 0, "global") sch.cache_write(block_opaque, 0, "global") sch.cache_write(block_match_buffer, 0, "global") @@ -869,58 +871,58 @@ def test_cache_write_opaque_access(): verify_trace_roundtrip(sch=sch, mod=opaque_access) -def test_cache_write_location(): +def test_cache_write_location(use_block_name): sch = tir.Schedule(func_multi_consumer, debug_mask="all") - block_a = sch.get_block("A") + block_a = "A" if use_block_name else sch.get_block("A") sch.cache_write(block_a, 0, "global") tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) -def test_continuous_cache_write(): +def test_continuous_cache_write(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.cache_write(block_b, 0, "shared") sch.cache_write(block_b, 0, "local") tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_cache_write_with_block_predicate(): +def test_cache_write_with_block_predicate(use_block_name): # cache write for intermediate buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("producer") + block = "producer" if use_block_name else sch.get_block("producer") sch.cache_write(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) # cache write for external buffer sch = tir.Schedule(func_with_block_predicate, debug_mask="all") - block = sch.get_block("consumer") + block = "consumer" if use_block_name else sch.get_block("consumer") sch.cache_write(block, 0, "shared") tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate) -def test_cache_write_fail_multi_producer(): +def test_cache_write_fail_multi_producer(use_block_name): sch = tir.Schedule(func_multi_producer, debug_mask="all") - 
block_a0 = sch.get_block("A0") - block_a1 = sch.get_block("A1") + block_a0 = "A0" if use_block_name else sch.get_block("A0") + block_a1 = "A1" if use_block_name else sch.get_block("A1") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_a0, 0, "global") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_a1, 0, "global") -def test_cache_write_fail_index_out_of_bound(): +def test_cache_write_fail_index_out_of_bound(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 1, "global") -def test_cache_write_fail_invalid_storage_scope(): +def test_cache_write_fail_invalid_storage_scope(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.cache_write(block_b, 0, "test_scope") diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 3772d9a4e0fec..0c20a4783ca02 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1052,17 +1052,19 @@ def static_bound_after_compute_at(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_compute_at_two_elementwise(): + +def test_compute_at_two_elementwise(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") - loop, _ = sch.get_loops(sch.get_block("C")) + block = "B" if use_block_name else sch.get_block("B") + loop, _ = sch.get_loops("C" if use_block_name else sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=True) tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_compute_at_blockized_1(): +def test_compute_at_blockized_1(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_outer")) @@ -1071,7 +1073,7 @@ def test_compute_at_blockized_1(): verify_trace_roundtrip(sch=sch, mod=blockized_1) -def test_compute_at_blockized_2(): +def test_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") block = sch.get_block("B_outer") _, loop, _, _ = sch.get_loops(sch.get_block("C")) @@ -1080,7 +1082,7 @@ def test_compute_at_blockized_2(): verify_trace_roundtrip(sch=sch, mod=blockized_2) -def test_compute_at_cuda_matmul_0(): +def test_compute_at_cuda_matmul_0(use_block_name): sch = tir.Schedule(cuda_matmul_0, debug_mask="all") block = sch.get_block("C") _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_block("C_local")) @@ -1089,7 +1091,7 @@ def test_compute_at_cuda_matmul_0(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_0) -def test_compute_at_cuda_matmul_1(): +def test_compute_at_cuda_matmul_1(use_block_name): sch = tir.Schedule(cuda_matmul_1, debug_mask="all") block = sch.get_block("A_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1098,7 +1100,7 @@ def test_compute_at_cuda_matmul_1(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1) -def 
test_compute_at_cuda_matmul_2(): +def test_compute_at_cuda_matmul_2(use_block_name): sch = tir.Schedule(cuda_matmul_2, debug_mask="all") block = sch.get_block("B_shared_local") _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1107,7 +1109,7 @@ def test_compute_at_cuda_matmul_2(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_2) -def test_compute_at_cuda_matmul_3(): +def test_compute_at_cuda_matmul_3(use_block_name): sch = tir.Schedule(cuda_matmul_3, debug_mask="all") block = sch.get_block("A_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1116,7 +1118,7 @@ def test_compute_at_cuda_matmul_3(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_3) -def test_compute_at_cuda_matmul_4(): +def test_compute_at_cuda_matmul_4(use_block_name): sch = tir.Schedule(cuda_matmul_4, debug_mask="all") block = sch.get_block("B_shared") _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C")) @@ -1125,7 +1127,7 @@ def test_compute_at_cuda_matmul_4(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4) -def test_compute_at_reduction_block(): +def test_compute_at_reduction_block(use_block_name): sch = tir.Schedule(multi_reduction, debug_mask="all") block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) @@ -1134,7 +1136,7 @@ def test_compute_at_reduction_block(): verify_trace_roundtrip(sch=sch, mod=multi_reduction) -def test_compute_at_tiled_pooling_read_cache(): +def test_compute_at_tiled_pooling_read_cache(use_block_name): sch = tir.Schedule(tiled_pooling_read_cache, debug_mask="all") compute = sch.get_block("compute") _, w_o, _, _, _, _ = sch.get_loops(compute) @@ -1144,7 +1146,7 @@ def test_compute_at_tiled_pooling_read_cache(): verify_trace_roundtrip(sch=sch, mod=tiled_pooling_read_cache) -def test_compute_at_non_uniform_tiled_conv(): +def test_compute_at_non_uniform_tiled_conv(use_block_name): sch = tir.Schedule(non_uniform_tiled_conv, debug_mask="all") compute = sch.get_block("compute") sch.compute_at(sch.get_block("cache"), sch.get_loops(compute)[1]) @@ -1152,7 +1154,7 @@ def test_compute_at_non_uniform_tiled_conv(): verify_trace_roundtrip(sch=sch, mod=non_uniform_tiled_conv) -def test_compute_at_concat(): +def test_compute_at_concat(use_block_name): sch = tir.Schedule(concat_two_elemwise, debug_mask="all") concat = sch.get_block("T_concat") add1 = sch.get_block("T_add_1") @@ -1164,7 +1166,7 @@ def test_compute_at_concat(): verify_trace_roundtrip(sch=sch, mod=concat_two_elemwise) -def test_compute_at_tiled_repeat_op(): +def test_compute_at_tiled_repeat_op(use_block_name): sch = tir.Schedule(tiled_repeat_op, debug_mask="all") outer_ax, _ = sch.get_loops(sch.get_block("T_repeat")) sch.compute_at(sch.get_block("T_add"), outer_ax) @@ -1172,7 +1174,7 @@ def test_compute_at_tiled_repeat_op(): verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op) -def test_reverse_compute_at_tiled(): +def test_reverse_compute_at_tiled(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) @@ -1181,7 +1183,7 @@ def test_reverse_compute_at_tiled(): verify_trace_roundtrip(sch=sch, mod=tiled) -def test_reverse_compute_at_tiled_trivial_binding(): +def test_reverse_compute_at_tiled_trivial_binding(use_block_name): sch = tir.Schedule(tiled_trivial_binding, debug_mask="all") block = sch.get_block("C") _, _, loop, _ = sch.get_loops(sch.get_block("B")) @@ -1190,7 +1192,7 @@ def test_reverse_compute_at_tiled_trivial_binding(): verify_trace_roundtrip(sch=sch, 
mod=tiled_trivial_binding) -def test_reverse_compute_at_blockized_2(): +def test_reverse_compute_at_blockized_2(use_block_name): sch = tir.Schedule(blockized_2, debug_mask="all") block = sch.get_block("C") _, loop = sch.get_loops(sch.get_block("B_outer")) @@ -1199,7 +1201,7 @@ def test_reverse_compute_at_blockized_2(): verify_trace_roundtrip(sch=sch, mod=blockized_2) -def test_reverse_compute_at_factorized(): +def test_reverse_compute_at_factorized(use_block_name): sch = tir.Schedule(factorized, debug_mask="all") block = sch.get_block("B") _, loop, _, _ = sch.get_loops(sch.get_block("B_rf")) @@ -1208,7 +1210,7 @@ def test_reverse_compute_at_factorized(): verify_trace_roundtrip(sch=sch, mod=factorized) -def test_reverse_compute_at_floordiv_and_floormod_indices(): +def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name): sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") A = sch.get_block("A") B = sch.get_block("B") @@ -1219,7 +1221,7 @@ def test_reverse_compute_at_floordiv_and_floormod_indices(): verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices) -def test_read_out_of_bound(): +def test_read_out_of_bound(use_block_name): sch = tir.Schedule(read_out_of_bound, debug_mask="all") block = sch.get_block("B") (loop,) = sch.get_loops(sch.get_block("C")) @@ -1228,7 +1230,7 @@ def test_read_out_of_bound(): verify_trace_roundtrip(sch=sch, mod=read_out_of_bound) -def test_compact_dataflow(): +def test_compact_dataflow(use_block_name): sch = tir.Schedule(not_all_compact_data_flow, debug_mask="all") block = sch.get_block("B") _, loop = sch.get_loops(sch.get_block("C_1")) @@ -1237,7 +1239,7 @@ def test_compact_dataflow(): verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow) -def test_compute_at_simplify_static_bound(): +def test_compute_at_simplify_static_bound(use_block_name): sch = tir.Schedule(static_bound, debug_mask="all") block = sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) @@ -1246,7 +1248,7 @@ def test_compute_at_simplify_static_bound(): verify_trace_roundtrip(sch=sch, mod=static_bound) -def test_compute_at_non_perfect_channel_group(): +def test_compute_at_non_perfect_channel_group(use_block_name): @T.prim_func def grouped_channel_bias( X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] @@ -1284,7 +1286,7 @@ def grouped_channel_bias_non_perfect_tiled( tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) -def test_fail_subtree_complete_block(): +def test_fail_subtree_complete_block(use_block_name): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0") loop, _ = sch.get_loops(sch.get_block("C")) @@ -1292,47 +1294,47 @@ def test_fail_subtree_complete_block(): sch.compute_at(block, loop) -def test_fail_not_in_same_scope(): +def test_fail_not_in_same_scope(use_block_name): sch = tir.Schedule(blockized_1, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C_inner")) with pytest.raises(tvm.tir.ScheduleError, match="same block scope"): sch.compute_at(block, loop) -def test_fail_loop_is_ancestor_of_block(): +def test_fail_loop_is_ancestor_of_block(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("B")) with pytest.raises(tvm.tir.ScheduleError, match="ancestor of block"): sch.compute_at(block, loop) 
-def test_fail_output_block(): +def test_fail_output_block(use_block_name): sch = tir.Schedule(tiled, debug_mask="all") - block = sch.get_block("C") + block = "C" if use_block_name else sch.get_block("C") loop, _, _, _ = sch.get_loops(sch.get_block("B")) with pytest.raises(tvm.tir.ScheduleError, match="output block"): sch.compute_at(block, loop) -def test_fail_all_consumers_under_loop(): +def test_fail_all_consumers_under_loop(use_block_name): sch = tir.Schedule(fail_all_consumers_under_loop, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the consumer"): sch.compute_at(block, loop) -def test_fail_all_producers_under_loop(): +def test_fail_all_producers_under_loop(use_block_name): sch = tir.Schedule(fail_all_producers_under_loop, debug_mask="all") - block = sch.get_block("D") + block = "D" if use_block_name else sch.get_block("D") loop, _ = sch.get_loops(sch.get_block("C")) with pytest.raises(tvm.tir.ScheduleError, match="requires all the producer"): sch.reverse_compute_at(block, loop) -def test_compute_at_int64_loop(): +def test_compute_at_int64_loop(use_block_name): def _create_prim_func(): n = te.var("n", dtype="int64") m = te.var("m", dtype="int64") @@ -1344,8 +1346,8 @@ def _create_prim_func(): mod = _create_prim_func() sch = tir.Schedule(mod, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = "C" if use_block_name else sch.get_block("C") + block_d = "D" if use_block_name else sch.get_block("D") i, _ = sch.get_loops(block_d) sch.compute_at(block_c, i) verify_trace_roundtrip(sch=sch, mod=mod) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 84fb88218997f..617e13db27f60 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -587,10 +587,12 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined( # pylint: enable=no-member,invalid-name,unused-variable +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_compute_inline_elementwise(): + +def test_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) @@ -598,9 +600,9 @@ def test_compute_inline_elementwise(): verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_compute_inline_under_loop(): +def test_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) @@ -608,9 +610,9 @@ def test_compute_inline_under_loop(): verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) -def test_compute_inline_as_dce(): +def test_compute_inline_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_standalone_dce, 
sch.mod["main"]) @@ -618,9 +620,9 @@ def test_compute_inline_as_dce(): verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) -def test_compute_inline_multi_consumer(): +def test_compute_inline_multi_consumer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") block_c = sch.get_block("C") block_d = sch.get_block("D") sch.compute_inline(block_b) @@ -630,81 +632,81 @@ def test_compute_inline_multi_consumer(): verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) -def test_compute_inline_fail_multi_writer(): +def test_compute_inline_fail_multi_writer(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_reverse_compute_inline_elementwise(): +def test_reverse_compute_inline_elementwise(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_reverse_compute_inline_under_loop(): +def test_reverse_compute_inline_under_loop(use_block_name): sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) -def test_reverse_compute_inline_fail_as_dce(): +def test_reverse_compute_inline_fail_as_dce(use_block_name): sch = tir.Schedule(elementwise_standalone, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_b) -def test_reverse_compute_inline_fail_multi_producer(): +def test_reverse_compute_inline_fail_multi_producer(use_block_name): sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") - block_d = sch.get_block("D") + block_d = "D" if use_block_name else sch.get_block("D") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_d) -def test_reverse_compute_inline_fail_multi_reader(): +def test_reverse_compute_inline_fail_multi_reader(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) -def test_reverse_compute_multi_reverse_loads(): +def test_reverse_compute_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) -def test_reverse_compute_inline_affine_load(): +def test_reverse_compute_inline_affine_load(use_block_name): sch = 
tir.Schedule(elementwise_reverse_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load) -def test_reverse_compute_inline_multi_affine_load(): +def test_reverse_compute_inline_multi_affine_load(use_block_name): sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load) -def test_reverse_compute_inline_affine_load_unit_iter(): +def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal( elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"] @@ -712,9 +714,9 @@ def test_reverse_compute_inline_affine_load_unit_iter(): verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter) -def test_reverse_compute_inline_affine_load_unit_iter_simplified(): +def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name): sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter_simplified, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal( elementwise_reverse_affine_load_unit_iter_simplified_inlined, sch.mod["main"] @@ -723,10 +725,10 @@ def test_reverse_compute_inline_affine_load_unit_iter_simplified(): @pytest.mark.parametrize("reverse_order", [True, False]) -def test_reverse_compute_inline_affine_chain(reverse_order): +def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): sch = tir.Schedule(elementwise_reverse_affine_chain, debug_mask="all") - block_c = sch.get_block("C") - block_d = sch.get_block("D") + block_c = "C" if use_block_name else sch.get_block("C") + block_d = "D" if use_block_name else sch.get_block("D") if reverse_order: sch.reverse_compute_inline(block_d) sch.reverse_compute_inline(block_c) @@ -737,68 +739,68 @@ def test_reverse_compute_inline_affine_chain(reverse_order): verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_chain) -def test_reverse_compute_fail_non_affine_load(): +def test_reverse_compute_fail_non_affine_load(use_block_name): sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) -def test_reverse_compute_fail_multi_reverse_loads(): +def test_reverse_compute_fail_multi_reverse_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_c = sch.get_block("C") + block_c = "C" if use_block_name else sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) -def test_opaque_access_load(): +def test_opaque_access_load(use_block_name): sch = tir.Schedule(opaque_access_load, debug_mask="all") - block_b = 
sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_opaque_access_store(): +def test_opaque_access_store(use_block_name): sch = tir.Schedule(opaque_access_store, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_buffer_matched(): +def test_buffer_matched(use_block_name): sch = tir.Schedule(buffer_matched, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) -def test_output_block(): +def test_output_block(use_block_name): sch = tir.Schedule(matmul_relu, debug_mask="all") block = sch.get_block("compute") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block) -def test_compute_inline_predicate(): +def test_compute_inline_predicate(use_block_name): sch = tir.Schedule(elementwise_predicate, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) -def test_compute_inline_multi_loads(): +def test_compute_inline_multi_loads(use_block_name): sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") - block_b = sch.get_block("B") + block_b = "B" if use_block_name else sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) -def test_compute_inline_with_opaque_access(): +def test_compute_inline_with_opaque_access(use_block_name): """Test not rewrite opaque reads/writes after irrelavant compute inline""" sch = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all") - BB = sch.get_block("BB") + BB = "BB" if use_block_name else sch.get_block("BB") sch.compute_inline(BB) tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"]) @@ -810,10 +812,10 @@ def test_inline_block_with_init(): sch.compute_inline(block=block) -def test_compute_inline_opaque_access_with_tvm_access_ptr(): +def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): """Test opaque access with tvm_access_ptr after compute inline""" sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") - compute = sch.get_block("compute") + compute = "compute" if use_block_name else sch.get_block("compute") sch.compute_inline(compute) tvm.ir.assert_structural_equal( exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index a8348afb457d5..f3503460e50ac 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -215,19 +215,21 @@ def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_reduction_decompose0(): + +def test_reduction_decompose0(use_block_name): s = tir.Schedule(matmul, debug_mask="all") - C = s.get_block("update") + C = "update" if use_block_name 
else s.get_block("update") i, j, k = s.get_loops(C) s.decompose_reduction(C, i) tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"]) verify_trace_roundtrip(s, mod=matmul) -def test_reduction_decompose1(): +def test_reduction_decompose1(use_block_name): s = tir.Schedule(rowsum_blockized, debug_mask="all") - blockized_B = s.get_block("blockized_B") + blockized_B = "blockized_B" if use_block_name else s.get_block("blockized_B") io, ko = s.get_loops(blockized_B) s.decompose_reduction(blockized_B, io) tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"]) diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py index 9b2e37a19813a..c6776b0c8a3e2 100644 --- a/tests/python/unittest/test_tir_schedule_reindex.py +++ b/tests/python/unittest/test_tir_schedule_reindex.py @@ -168,35 +168,43 @@ def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "f B[vi, vj] = A[vj, vi] + A[vi, vj] -def test_reindex_read_basic(): +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) +use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True}) + + +def test_reindex_read_basic(use_block_name, use_buffer_name): sch = tir.Schedule(transpose_elementwise) - block = sch.get_block("B") - sch.reindex(block, 0, "read") + block = "B" if use_block_name else sch.get_block("B") + buf = "A" if use_buffer_name else ("read", 0) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=transpose_elementwise) -def test_conv2d_reindex_read(): +def test_conv2d_reindex_read(use_block_name, use_buffer_name): sch = tir.Schedule(conv2d_nhwc) - block = sch.get_block("conv2d_nhwc") - sch.reindex(block, 1, "read") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") + buf = "Weight" if use_buffer_name else ("read", 1) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) -def test_matmul_reindex_write(): +def test_matmul_reindex_write(use_block_name, use_buffer_name): sch = tir.Schedule(matmul) - block = sch.get_block("matmul") - sch.reindex(block, 0, "write") + block = "matmul" if use_block_name else sch.get_block("matmul") + buf = "C" if use_buffer_name else ("write", 0) + sch.reindex(block, buf) tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=matmul) -def test_reindex_fail_multiple_read(): +def test_reindex_fail_multiple_read(use_block_name, use_buffer_name): sch = tir.Schedule(multiple_read) - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") + buf = "A" if use_buffer_name else ("read", 0) with pytest.raises(ScheduleError): - sch.reindex(block, 0, "read") + sch.reindex(block, buf) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 17f35ea8f72fe..0c2a3d27ffdb2 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -179,10 +179,16 @@ def test_sample_perfect_tile_composite(): verify_trace_roundtrip(sch, mod=elementwise) -def test_sample_compute_location(): +use_sugared_block = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + + +def test_sample_compute_location(use_sugared_block): n = 
100 sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all") - pad_input = sch.get_block("PadInput") + if use_sugared_block: + pad_input = "PadInput" + else: + pad_input = sch.get_block("PadInput") decision_dict = dict() for _ in range(n): _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py index 29c4880f77622..b2e8479462ebe 100644 --- a/tests/python/unittest/test_tir_schedule_set_scope.py +++ b/tests/python/unittest/test_tir_schedule_set_scope.py @@ -86,20 +86,21 @@ def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_set_scope(): +def test_set_scope(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - s.set_scope(s.get_block("B"), 0, "shared") + s.set_scope('B' if use_block_name else s.get_block("B"), 0, "shared") tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) -def test_set_scope_fail_on_output_buffer(): +def test_set_scope_fail_on_output_buffer(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') with pytest.raises(tvm.tir.ScheduleError): - s.set_scope(s.get_block("C"), 0, "shared") + s.set_scope('C' if use_block_name else s.get_block("C"), 0, "shared") def test_set_scope_fail_on_index_out_of_bound(): diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 3b699fd8f1b2d..072640c8f3af5 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -98,10 +98,12 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) -def test_storage_align(): +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + +def test_storage_align(use_block_name): func = element_wise s = tir.Schedule(func, debug_mask='all') - B = s.get_block("B") + B = 'B' if use_block_name else s.get_block("B") s.storage_align(B, 0, axis=0, factor=128, offset=127) tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index e184bc3f627c3..205bd5091268b 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -171,15 +171,13 @@ def conv2d_nhwc_transformed( # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on -use_sugared_transform = tvm.testing.parameter( - by_dict={"transform_layout": False, "transform_layout_sugared": True} -) +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) -def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): +def test_two_elementwise_transform_intermediate_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( block="B", buffer="B", @@ -193,10 +191,10 @@ def 
test_two_elementwise_transform_intermediate_buffer(use_sugared_transform): verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_input_buffer(use_sugared_transform): +def test_two_elementwise_transform_input_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( index_map=packed_index_map_func, block="B", @@ -210,10 +208,10 @@ def test_two_elementwise_transform_input_buffer(use_sugared_transform): verify_trace_roundtrip(sch=sch, mod=two_elementwise) -def test_two_elementwise_transform_output_buffer(use_sugared_transform): +def test_two_elementwise_transform_output_buffer(use_block_name): sch = tir.Schedule(two_elementwise, debug_mask="all") - if use_sugared_transform: + if use_block_name: sch.transform_layout( index_map=packed_index_map_func, block="C", @@ -295,17 +293,17 @@ def summation_3d_split( tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) -def test_transform_block_layout_basic(): +def test_transform_block_layout_basic(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) -def test_transform_block_layout_conv2d_nhwc(): +def test_transform_block_layout_conv2d_nhwc(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") sch.transform_block_layout( block, lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 7 * 3 + rw * 3 + rc), @@ -314,16 +312,16 @@ def test_transform_block_layout_conv2d_nhwc(): verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) -def test_transform_block_layout_fail_non_affine(): +def test_transform_block_layout_fail_non_affine(use_block_name): sch = tir.Schedule(elementwise, debug_mask="all") - block = sch.get_block("B") + block = "B" if use_block_name else sch.get_block("B") with pytest.raises(tir.ScheduleError): sch.transform_block_layout(block, lambda i, j: (i + j,)) -def test_transform_block_layout_fail_mixed_iter_type(): +def test_transform_block_layout_fail_mixed_iter_type(use_block_name): sch = tir.Schedule(conv2d_nhwc, debug_mask="all") - block = sch.get_block("conv2d_nhwc") + block = "conv2d_nhwc" if use_block_name else sch.get_block("conv2d_nhwc") with pytest.raises(tir.ScheduleError): sch.transform_block_layout( block, diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 0d23d3f95211d..b7517aab7cd37 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -104,6 +104,8 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable +use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) + def test_tir_schedule_creation(): # Tests: @@ -131,24 +133,24 @@ def test_tir_schedule_get_block(): assert block.same_as(matmul.body.block.body.body.body[1].body.block) -def test_tir_schedule_get_loops(): +def test_tir_schedule_get_loops(use_block_name): # Tests: # - Schedule.get_loops # - Schedule.get sch = tir.Schedule(matmul, debug_mask="all") - block_rv = 
sch.get_block(name="update") - i, j, k = sch.get_loops(block_rv) + block = "update" if use_block_name else sch.get_block(name="update") + i, j, k = sch.get_loops(block) assert sch.get(i).loop_var.name == "i" assert sch.get(j).loop_var.name == "j" assert sch.get(k).loop_var.name == "k" -def test_tir_schedule_copy_1(): +def test_tir_schedule_copy_1(use_block_name): # Tests: # - Schedule.copy sch_1 = tir.Schedule(matmul, debug_mask="all") block_rv = sch_1.get_block(name="update") - i, j, k = sch_1.get_loops(block_rv) + i, j, k = sch_1.get_loops(block="update" if use_block_name else block_rv) assert sch_1.get(i).loop_var.name == "i" assert sch_1.get(j).loop_var.name == "j" assert sch_1.get(k).loop_var.name == "k" @@ -218,9 +220,9 @@ def test_get_child_blocks(): assert s.get(update) == s.get(blocks[1]) -def test_get_producers(): +def test_get_producers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = sch.get_block("relu") + block = "relu" if use_block_name else sch.get_block("relu") (producer,) = sch.get_producers(block) assert tvm.ir.structural_equal( sch.get_sref(producer).stmt, @@ -229,9 +231,9 @@ def test_get_producers(): verify_trace_roundtrip(sch, mod=matmul_relu) -def test_get_consumers(): +def test_get_consumers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - block = sch.get_block("matmul") + block = "matmul" if use_block_name else sch.get_block("matmul") (consumer,) = sch.get_consumers(block) assert tvm.ir.structural_equal( sch.get_sref(consumer).stmt, From 6d557ffae2db64fcea127b5e34089d9bc8e74fb0 Mon Sep 17 00:00:00 2001 From: driazati <9407960+driazati@users.noreply.github.com> Date: Thu, 9 Jun 2022 15:01:48 -0700 Subject: [PATCH 084/181] [ci] Rebuild Docker images if necessary (#11329) This rebuilds Docker images and uses them in later stages in the same build. If the build is running on `main`, then the images are uploaded to Docker Hub automatically once the run is complete. Images are always rebuilt, but Docker Hub functions as a cache. If there have been no changes to `docker/` since the last available hash on Docker Hub, then the build will just use the images from Hub. --- Jenkinsfile | 393 ++++++++++++++++--------- jenkins/Build.groovy.j2 | 23 ++ jenkins/Deploy.groovy.j2 | 50 ++++ jenkins/DockerBuild.groovy.j2 | 240 ++++++--------- jenkins/Jenkinsfile.j2 | 3 + jenkins/Lint.groovy.j2 | 10 +- jenkins/Prepare.groovy.j2 | 11 + tests/python/ci/test_ci.py | 97 +++++- tests/scripts/cmd_utils.py | 21 +- tests/scripts/git_utils.py | 1 + tests/scripts/http_utils.py | 34 +++ tests/scripts/should_rebuild_docker.py | 154 ++++++++++ 12 files changed, 737 insertions(+), 300 deletions(-) create mode 100644 tests/scripts/http_utils.py create mode 100755 tests/scripts/should_rebuild_docker.py diff --git a/Jenkinsfile b/Jenkinsfile index 0205a1e7364fe..ec4cea52d67b3 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,7 +45,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2022-06-02T14:03:43.284817 +// Generated at 2022-06-09T09:42:12.430625 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. 
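The caching behaviour described above comes down to three checks: look up the newest tag for the image on Docker Hub, split the git hash off the end of that tag (tags pushed from main follow the <date>-<time>-<short hash> pattern used in the deploy stage), and ask git whether anything under docker/ changed since that hash. A minimal sketch of that decision, assuming a hypothetical tlcpackstaging repository name and assuming Docker Hub returns the most recently pushed tag first, could look like the snippet below; the full script added by this patch (tests/scripts/should_rebuild_docker.py) layers logging, pagination, and error handling on top of the same idea.

import json
import subprocess
from urllib.request import urlopen

DOCKER_API_BASE = "https://hub.docker.com/v2/"

def newest_tag(repo: str) -> str:
    # repo is hypothetical here, e.g. "tlcpackstaging/ci_cpu"; tags pushed from
    # main look like "20220609-094212-abc12345" (date, time, short git hash)
    url = f"{DOCKER_API_BASE}repositories/{repo}/tags?page_size=1&page=1"
    with urlopen(url) as response:
        results = json.loads(response.read())["results"]
    # assumption: Docker Hub lists the most recently pushed tag first
    return results[0]["name"]

def needs_rebuild(repo: str) -> bool:
    git_hash = newest_tag(repo).split("-")[-1]
    # if the hash is unknown locally, the cached image cannot be matched to this checkout
    exists = subprocess.run(
        ["git", "rev-parse", "--quiet", "--verify", git_hash],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if exists.returncode != 0:
        return True
    # any change under docker/ since that hash invalidates the cached image
    diff = subprocess.run(
        ["git", "diff", git_hash, "--", "docker/"],
        stdout=subprocess.PIPE,
        encoding="utf-8",
    )
    return diff.stdout.strip() != ""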
--> @@ -97,6 +97,7 @@ if (currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) { // Filenames for stashing between build and test steps s3_prefix = "tvm-jenkins-artifacts-prod/tvm/${env.BRANCH_NAME}/${env.BUILD_NUMBER}" + // General note: Jenkins has limits on the size of a method (or top level code) // that are pretty strict, so most usage of groovy methods in these templates // are purely to satisfy the JVM @@ -171,6 +172,17 @@ def docker_init(image) { """, label: 'Clean old Docker images', ) + + if (image.contains("amazonaws.com")) { + // If this string is in the image name it's from ECR and needs to be pulled + // with the right credentials + ecr_pull(image) + } else { + sh( + script: "docker pull ${image}", + label: 'Pull docker image', + ) + } } def should_skip_slow_tests(pr_number) { @@ -273,16 +285,50 @@ def prepare() { } } } -def build_image(image_name) { - hash = sh( +def ecr_push(full_name) { + aws_account_id = sh( returnStdout: true, - script: 'git log -1 --format=\'%h\'' + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' ).trim() - def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" - sh( - script: "${docker_build} ${image_name} --spec ${full_name}", - label: 'Build docker image' - ) + + def ecr_name = "${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com/${full_name}" + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -x + docker tag ${full_name} \$AWS_ECR_REPO/${full_name} + docker push \$AWS_ECR_REPO/${full_name} + """, + label: 'Upload image to ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } + return ecr_name +} + +def ecr_pull(full_name) { aws_account_id = sh( returnStdout: true, script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', @@ -290,153 +336,144 @@ def build_image(image_name) { ).trim() try { - // Use a credential so Jenkins knows to scrub the AWS account ID which is nice - // (but so we don't have to rely it being hardcoded in Jenkins) - withCredentials([string( - credentialsId: 'aws-account-id', - variable: '_ACCOUNT_ID_DO_NOT_USE', - )]) { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION=us-west-2']) { - sh( - script: ''' - set -x - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com - ''', - label: 'Log in to ECR' - ) - sh( - script: """ - set -x - docker tag ${full_name} \$AWS_ACCOUNT_ID.dkr.ecr.\$AWS_DEFAULT_REGION.amazonaws.com/${full_name} - docker push \$AWS_ACCOUNT_ID.dkr.ecr.\$AWS_DEFAULT_REGION.amazonaws.com/${full_name} - """, - label: 'Upload image to ECR' - ) - } + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + 
label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) } } finally { - sh( - script: 'rm -f ~/.docker/config.json', - label: 'Clean up login credentials' - ) + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION=us-west-2', + "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } } +} + +def build_image(image_name) { + hash = sh( + returnStdout: true, + script: 'git log -1 --format=\'%h\'' + ).trim() + def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" sh( - script: "docker rmi ${full_name}", - label: 'Remove docker image' + script: "${docker_build} ${image_name} --spec ${full_name}", + label: 'Build docker image' ) + return ecr_push(full_name) } + def build_docker_images() { stage('Docker Image Build') { - // TODO in a follow up PR: Find ecr tag and use in subsequent builds - parallel 'ci-lint': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_lint') + parallel( + 'ci_arm': { + node('ARM') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_arm = build_image('ci_arm') + build_image('ci_arm') + } } - } - }, 'ci-cpu': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_cpu') + }, + 'ci_cpu': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_cpu = build_image('ci_cpu') + build_image('ci_cpu') + } } - } - }, 'ci-gpu': { - node('GPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_gpu') + }, + 'ci_gpu': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_gpu = build_image('ci_gpu') + build_image('ci_gpu') + } } - } - }, 'ci-qemu': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_qemu') + }, + 'ci_hexagon': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_hexagon = build_image('ci_hexagon') + build_image('ci_hexagon') + } } - } - }, 'ci-i386': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_i386') + }, + 'ci_i386': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_i386 = build_image('ci_i386') + build_image('ci_i386') + } } - } - }, 'ci-arm': { - node('ARM') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_arm') + }, + 'ci_lint': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_lint = build_image('ci_lint') + build_image('ci_lint') + } } - } - }, 'ci-wasm': { - node('CPU') { 
- timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_wasm') + }, + 'ci_qemu': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_qemu = build_image('ci_qemu') + build_image('ci_qemu') + } } - } - }, 'ci-hexagon': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_hexagon') + }, + 'ci_wasm': { + node('CPU') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // ci_wasm = build_image('ci_wasm') + build_image('ci_wasm') + } } - } - } - } - // // TODO: Once we are able to use the built images, enable this step - // // If the docker images changed, we need to run the image build before the lint - // // can run since it requires a base docker image. Most of the time the images - // // aren't build though so it's faster to use the same node that checks for - // // docker changes to run the lint in the usual case. - // stage('Sanity Check (re-run)') { - // timeout(time: max_time, unit: 'MINUTES') { - // node('CPU') { - // ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/sanity") { - // init_git() - // sh ( - // script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", - // label: 'Run lint', - // ) - // } - // } - // } - // } -} - -// Run make. First try to do an incremental make from a previous workspace in hope to -// accelerate the compilation. If something is wrong, clean the workspace and then -// build from scratch. -def make(docker_type, path, make_flag) { - timeout(time: max_time, unit: 'MINUTES') { - try { - cmake_build(docker_type, path, make_flag) - // always run cpp test when build - } catch (hudson.AbortException ae) { - // script exited due to user abort, directly throw instead of retry - if (ae.getMessage().contains('script returned exit code 143')) { - throw ae - } - echo 'Incremental compilation failed. Fall back to build from scratch' - sh ( - script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", - label: 'Clear old cmake workspace', - ) - cmake_build(docker_type, path, make_flag) - } + }, + ) } } def lint() { @@ -531,6 +568,29 @@ def add_hexagon_permissions() { ) } +// Run make. First try to do an incremental make from a previous workspace in hope to +// accelerate the compilation. If something is wrong, clean the workspace and then +// build from scratch. +def make(docker_type, path, make_flag) { + timeout(time: max_time, unit: 'MINUTES') { + try { + cmake_build(docker_type, path, make_flag) + } catch (hudson.AbortException ae) { + // script exited due to user abort, directly throw instead of retry + if (ae.getMessage().contains('script returned exit code 143')) { + throw ae + } + echo 'Incremental compilation failed. 
Fall back to build from scratch' + sh ( + script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", + label: 'Clear old cmake workspace', + ) + cmake_build(docker_type, path, make_flag) + } + } +} + + def build() { stage('Build') { environment { @@ -3239,6 +3299,25 @@ stage('Build packages') { } */ + +def update_docker(ecr_image, hub_image) { + if (!ecr_image.contains("amazonaws.com")) { + sh("echo Skipping '${ecr_image}' since it doesn't look like an ECR image") + return + } + docker_init(ecr_image) + sh( + script: """ + set -eux + docker tag \ + ${ecr_image} \ + ${hub_image} + docker push ${hub_image} + """, + label: "Update ${hub_image} on Docker Hub", + ) +} + def deploy_docs() { // Note: This code must stay in the Jenkinsfile to ensure that it runs // from a trusted context only @@ -3298,6 +3377,42 @@ def deploy() { } } } + if (env.BRANCH_NAME == 'main' && env.DEPLOY_DOCKER_IMAGES == 'yes' && rebuild_docker_images && upstream_revision != null) { + node('CPU') { + ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/deploy-docker") { + try { + withCredentials([string( + credentialsId: 'dockerhub-tlcpackstaging-key', + variable: 'DOCKERHUB_KEY', + )]) { + sh( + script: 'docker login -u tlcpackstaging -p ${DOCKERHUB_KEY}', + label: 'Log in to Docker Hub', + ) + } + def date_Ymd_HMS = sh( + script: 'python3 -c \'import datetime; print(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))\'', + label: 'Determine date', + returnStdout: true, + ).trim() + def tag = "${date_Ymd_HMS}-${upstream_revision.substring(0, 8)}" + update_docker(ci_arm, "tlcpackstaging/test_ci_arm:${tag}") + update_docker(ci_cpu, "tlcpackstaging/test_ci_cpu:${tag}") + update_docker(ci_gpu, "tlcpackstaging/test_ci_gpu:${tag}") + update_docker(ci_hexagon, "tlcpackstaging/test_ci_hexagon:${tag}") + update_docker(ci_i386, "tlcpackstaging/test_ci_i386:${tag}") + update_docker(ci_lint, "tlcpackstaging/test_ci_lint:${tag}") + update_docker(ci_qemu, "tlcpackstaging/test_ci_qemu:${tag}") + update_docker(ci_wasm, "tlcpackstaging/test_ci_wasm:${tag}") + } finally { + sh( + script: 'docker logout', + label: 'Clean up login credentials' + ) + } + } + } + } } } diff --git a/jenkins/Build.groovy.j2 b/jenkins/Build.groovy.j2 index 62ccc94916048..fcde53f559395 100644 --- a/jenkins/Build.groovy.j2 +++ b/jenkins/Build.groovy.j2 @@ -52,6 +52,29 @@ def add_hexagon_permissions() { {% endfor %} } +// Run make. First try to do an incremental make from a previous workspace in hope to +// accelerate the compilation. If something is wrong, clean the workspace and then +// build from scratch. +def make(docker_type, path, make_flag) { + timeout(time: max_time, unit: 'MINUTES') { + try { + cmake_build(docker_type, path, make_flag) + } catch (hudson.AbortException ae) { + // script exited due to user abort, directly throw instead of retry + if (ae.getMessage().contains('script returned exit code 143')) { + throw ae + } + echo 'Incremental compilation failed. 
Fall back to build from scratch' + sh ( + script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", + label: 'Clear old cmake workspace', + ) + cmake_build(docker_type, path, make_flag) + } + } +} + + def build() { stage('Build') { environment { diff --git a/jenkins/Deploy.groovy.j2 b/jenkins/Deploy.groovy.j2 index 917f71ded1ff3..3a049c5141dd9 100644 --- a/jenkins/Deploy.groovy.j2 +++ b/jenkins/Deploy.groovy.j2 @@ -16,6 +16,25 @@ stage('Build packages') { } */ + +def update_docker(ecr_image, hub_image) { + if (!ecr_image.contains("amazonaws.com")) { + sh("echo Skipping '${ecr_image}' since it doesn't look like an ECR image") + return + } + docker_init(ecr_image) + sh( + script: """ + set -eux + docker tag \ + ${ecr_image} \ + ${hub_image} + docker push ${hub_image} + """, + label: "Update ${hub_image} on Docker Hub", + ) +} + def deploy_docs() { // Note: This code must stay in the Jenkinsfile to ensure that it runs // from a trusted context only @@ -67,5 +86,36 @@ def deploy() { } } } + if (env.BRANCH_NAME == 'main' && env.DEPLOY_DOCKER_IMAGES == 'yes' && rebuild_docker_images && upstream_revision != null) { + node('CPU') { + ws({{ m.per_exec_ws('tvm/deploy-docker') }}) { + try { + withCredentials([string( + credentialsId: 'dockerhub-tlcpackstaging-key', + variable: 'DOCKERHUB_KEY', + )]) { + sh( + script: 'docker login -u tlcpackstaging -p ${DOCKERHUB_KEY}', + label: 'Log in to Docker Hub', + ) + } + def date_Ymd_HMS = sh( + script: 'python3 -c \'import datetime; print(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))\'', + label: 'Determine date', + returnStdout: true, + ).trim() + def tag = "${date_Ymd_HMS}-${upstream_revision.substring(0, 8)}" + {% for image in images %} + update_docker({{ image.name }}, "tlcpackstaging/test_{{ image.name }}:${tag}") + {% endfor %} + } finally { + sh( + script: 'docker logout', + label: 'Clean up login credentials' + ) + } + } + } + } } } diff --git a/jenkins/DockerBuild.groovy.j2 b/jenkins/DockerBuild.groovy.j2 index e9d80801a9d9c..a0ff666773f75 100644 --- a/jenkins/DockerBuild.groovy.j2 +++ b/jenkins/DockerBuild.groovy.j2 @@ -1,13 +1,47 @@ -def build_image(image_name) { - hash = sh( +def ecr_push(full_name) { + aws_account_id = sh( returnStdout: true, - script: 'git log -1 --format=\'%h\'' + script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"', + label: 'Get AWS ID' ).trim() - def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" - sh( - script: "${docker_build} ${image_name} --spec ${full_name}", - label: 'Build docker image' - ) + + def ecr_name = "${aws_account_id}.{{ aws_ecr_url }}/${full_name}" + try { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -x + docker tag ${full_name} \$AWS_ECR_REPO/${full_name} + docker push \$AWS_ECR_REPO/${full_name} + """, + label: 'Upload image to ECR' + ) + } + } finally { + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } + } + return ecr_name +} + +def ecr_pull(full_name) { aws_account_id = sh( returnStdout: true, script: 'aws sts 
get-caller-identity | grep Account | cut -f4 -d\\"', @@ -15,152 +49,68 @@ def build_image(image_name) { ).trim() try { - // Use a credential so Jenkins knows to scrub the AWS account ID which is nice - // (but so we don't have to rely it being hardcoded in Jenkins) - withCredentials([string( - credentialsId: 'aws-account-id', - variable: '_ACCOUNT_ID_DO_NOT_USE', - )]) { - withEnv([ - "AWS_ACCOUNT_ID=${aws_account_id}", - 'AWS_DEFAULT_REGION=us-west-2']) { - sh( - script: ''' - set -x - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com - ''', - label: 'Log in to ECR' - ) - sh( - script: """ - set -x - docker tag ${full_name} \$AWS_ACCOUNT_ID.dkr.ecr.\$AWS_DEFAULT_REGION.amazonaws.com/${full_name} - docker push \$AWS_ACCOUNT_ID.dkr.ecr.\$AWS_DEFAULT_REGION.amazonaws.com/${full_name} - """, - label: 'Upload image to ECR' - ) - } + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: ''' + set -eux + aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO + ''', + label: 'Log in to ECR' + ) + sh( + script: """ + set -eux + docker pull ${full_name} + """, + label: 'Pull image from ECR' + ) } } finally { - sh( - script: 'rm -f ~/.docker/config.json', - label: 'Clean up login credentials' - ) + withEnv([ + "AWS_ACCOUNT_ID=${aws_account_id}", + 'AWS_DEFAULT_REGION={{ aws_default_region }}', + "AWS_ECR_REPO=${aws_account_id}.{{ aws_ecr_url }}"]) { + sh( + script: 'docker logout $AWS_ECR_REPO', + label: 'Clean up login credentials' + ) + } } +} + +def build_image(image_name) { + hash = sh( + returnStdout: true, + script: 'git log -1 --format=\'%h\'' + ).trim() + def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}" sh( - script: "docker rmi ${full_name}", - label: 'Remove docker image' + script: "${docker_build} ${image_name} --spec ${full_name}", + label: 'Build docker image' ) + return ecr_push(full_name) } + def build_docker_images() { stage('Docker Image Build') { - // TODO in a follow up PR: Find ecr tag and use in subsequent builds - parallel 'ci-lint': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_lint') - } - } - }, 'ci-cpu': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_cpu') + parallel( + {% for image in images %} + '{{ image.name }}': { + node('{{ image.platform }}') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + // We're purposefully not setting the built image here since they + // are not yet being uploaded to tlcpack + // {{ image.name }} = build_image('{{ image.name }}') + build_image('{{ image.name }}') + } } - } - }, 'ci-gpu': { - node('GPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_gpu') - } - } - }, 'ci-qemu': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_qemu') - } - } - }, 'ci-i386': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_i386') - } - } - }, 'ci-arm': { - node('ARM') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_arm') - } - } - }, 'ci-wasm': { - node('CPU') { - 
timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_wasm') - } - } - }, 'ci-hexagon': { - node('CPU') { - timeout(time: max_time, unit: 'MINUTES') { - docker_init('none') - init_git() - build_image('ci_hexagon') - } - } - } - } - // // TODO: Once we are able to use the built images, enable this step - // // If the docker images changed, we need to run the image build before the lint - // // can run since it requires a base docker image. Most of the time the images - // // aren't build though so it's faster to use the same node that checks for - // // docker changes to run the lint in the usual case. - // stage('Sanity Check (re-run)') { - // timeout(time: max_time, unit: 'MINUTES') { - // node('CPU') { - // ws({{ m.per_exec_ws('tvm/sanity') }}) { - // init_git() - // sh ( - // script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", - // label: 'Run lint', - // ) - // } - // } - // } - // } -} - -// Run make. First try to do an incremental make from a previous workspace in hope to -// accelerate the compilation. If something is wrong, clean the workspace and then -// build from scratch. -def make(docker_type, path, make_flag) { - timeout(time: max_time, unit: 'MINUTES') { - try { - cmake_build(docker_type, path, make_flag) - // always run cpp test when build - } catch (hudson.AbortException ae) { - // script exited due to user abort, directly throw instead of retry - if (ae.getMessage().contains('script returned exit code 143')) { - throw ae - } - echo 'Incremental compilation failed. Fall back to build from scratch' - sh ( - script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", - label: 'Clear old cmake workspace', - ) - cmake_build(docker_type, path, make_flag) - } + }, + {% endfor %} + ) } } diff --git a/jenkins/Jenkinsfile.j2 b/jenkins/Jenkinsfile.j2 index c165de964feb4..4e344c56d7f72 100644 --- a/jenkins/Jenkinsfile.j2 +++ b/jenkins/Jenkinsfile.j2 @@ -100,6 +100,9 @@ if (currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) { {% set hexagon_api = ['build/hexagon_api_output',] %} s3_prefix = "tvm-jenkins-artifacts-prod/tvm/${env.BRANCH_NAME}/${env.BUILD_NUMBER}" +{% set aws_default_region = "us-west-2" %} +{% set aws_ecr_url = "dkr.ecr." 
+ aws_default_region + ".amazonaws.com" %} + // General note: Jenkins has limits on the size of a method (or top level code) // that are pretty strict, so most usage of groovy methods in these templates // are purely to satisfy the JVM diff --git a/jenkins/Lint.groovy.j2 b/jenkins/Lint.groovy.j2 index 40dad3aef7be3..3ede64301c935 100644 --- a/jenkins/Lint.groovy.j2 +++ b/jenkins/Lint.groovy.j2 @@ -2,11 +2,11 @@ def lint() { stage('Lint') { parallel( {% call m.sharded_lint_step( - name='Lint', - num_shards=2, - node='CPU-SMALL', - ws='tvm/lint', - docker_image='ci_lint', + name='Lint', + num_shards=2, + node='CPU-SMALL', + ws='tvm/lint', + docker_image='ci_lint', ) %} sh ( diff --git a/jenkins/Prepare.groovy.j2 b/jenkins/Prepare.groovy.j2 index 2900775f49452..894ddc72eeb7b 100644 --- a/jenkins/Prepare.groovy.j2 +++ b/jenkins/Prepare.groovy.j2 @@ -69,6 +69,17 @@ def docker_init(image) { """, label: 'Clean old Docker images', ) + + if (image.contains("amazonaws.com")) { + // If this string is in the image name it's from ECR and needs to be pulled + // with the right credentials + ecr_pull(image) + } else { + sh( + script: "docker pull ${image}", + label: 'Pull docker image', + ) + } } def should_skip_slow_tests(pr_number) { diff --git a/tests/python/ci/test_ci.py b/tests/python/ci/test_ci.py index 042c109dd9d49..7ef2f0cd58452 100644 --- a/tests/python/ci/test_ci.py +++ b/tests/python/ci/test_ci.py @@ -18,9 +18,11 @@ import subprocess import sys import json +from tempfile import tempdir import textwrap import pytest import tvm.testing +from pathlib import Path from test_utils import REPO_ROOT @@ -29,11 +31,13 @@ class TempGit: def __init__(self, cwd): self.cwd = cwd - def run(self, *args): - proc = subprocess.run(["git"] + list(args), cwd=self.cwd) + def run(self, *args, **kwargs): + proc = subprocess.run(["git"] + list(args), encoding="utf-8", cwd=self.cwd, **kwargs) if proc.returncode != 0: raise RuntimeError(f"git command failed: '{args}'") + return proc + def test_cc_reviewers(tmpdir_factory): reviewers_script = REPO_ROOT / "tests" / "scripts" / "github_cc_reviewers.py" @@ -747,5 +751,94 @@ def run(type, data, check): ) +@pytest.mark.parametrize( + "changed_files,name,check,expected_code", + [ + d.values() + for d in [ + dict( + changed_files=[], + name="abc", + check="Image abc is not using new naming scheme", + expected_code=1, + ), + dict( + changed_files=[], name="123-123-abc", check="No extant hash found", expected_code=1 + ), + dict( + changed_files=[["test.txt"]], + name=None, + check="Did not find changes, no rebuild necessary", + expected_code=0, + ), + dict( + changed_files=[["test.txt"], ["docker/test.txt"]], + name=None, + check="Found docker changes", + expected_code=2, + ), + ] + ], +) +def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expected_code): + tag_script = REPO_ROOT / "tests" / "scripts" / "should_rebuild_docker.py" + + git = TempGit(tmpdir_factory.mktemp("tmp_git_dir")) + git.run("init") + git.run("config", "user.name", "ci") + git.run("config", "user.email", "email@example.com") + git.run("checkout", "-b", "main") + git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") + + git_path = Path(git.cwd) + for i, commits in enumerate(changed_files): + for filename in commits: + path = git_path / filename + path.parent.mkdir(exist_ok=True, parents=True) + path.touch() + git.run("add", filename) + + git.run("commit", "-m", f"message {i}") + + if name is None: + ref = "HEAD" + if len(changed_files) > 1: + ref = 
f"HEAD~{len(changed_files) - 1}" + proc = git.run("rev-parse", ref, stdout=subprocess.PIPE) + last_hash = proc.stdout.strip() + name = f"123-123-{last_hash}" + + docker_data = { + "repositories/tlcpack": { + "results": [ + { + "name": "ci-something", + }, + { + "name": "something-else", + }, + ], + }, + "repositories/tlcpack/ci-something/tags": { + "results": [{"name": name}, {"name": name + "old"}], + }, + } + + proc = subprocess.run( + [ + str(tag_script), + "--testing-docker-data", + json.dumps(docker_data), + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + cwd=git.cwd, + ) + + assert_in(check, proc.stdout) + assert proc.returncode == expected_code + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/cmd_utils.py b/tests/scripts/cmd_utils.py index 272086796e8df..771c3ee52dbd2 100644 --- a/tests/scripts/cmd_utils.py +++ b/tests/scripts/cmd_utils.py @@ -44,18 +44,21 @@ def init_log(): class Sh: - def __init__(self, env=None): + def __init__(self, env=None, cwd=None): self.env = os.environ.copy() if env is not None: self.env.update(env) + self.cwd = cwd def run(self, cmd: str, **kwargs): logging.info(f"+ {cmd}") - if "check" not in kwargs: - kwargs["check"] = True - if "shell" not in kwargs: - kwargs["shell"] = True - if "env" not in kwargs: - kwargs["env"] = self.env - - subprocess.run(cmd, **kwargs) + defaults = { + "check": True, + "shell": True, + "env": self.env, + "encoding": "utf-8", + "cwd": self.cwd, + } + defaults.update(kwargs) + + return subprocess.run(cmd, **defaults) diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py index 267756d859050..c5ea8d85e0718 100644 --- a/tests/scripts/git_utils.py +++ b/tests/scripts/git_utils.py @@ -20,6 +20,7 @@ import subprocess import re import base64 +import logging from urllib import request from typing import Dict, Tuple, Any, Optional, List diff --git a/tests/scripts/http_utils.py b/tests/scripts/http_utils.py new file mode 100644 index 0000000000000..c14259479d3be --- /dev/null +++ b/tests/scripts/http_utils.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +import logging +from urllib import request +from typing import Dict, Any, Optional + + +def get(url: str, headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + logging.info(f"Requesting GET to {url}") + if headers is None: + headers = {} + req = request.Request(url, headers=headers) + with request.urlopen(req) as response: + response_headers = {k: v for k, v in response.getheaders()} + response = json.loads(response.read()) + + return response, response_headers diff --git a/tests/scripts/should_rebuild_docker.py b/tests/scripts/should_rebuild_docker.py new file mode 100755 index 0000000000000..dc12c38de8303 --- /dev/null +++ b/tests/scripts/should_rebuild_docker.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import argparse +import datetime +import json +import logging +import subprocess + +from typing import Dict, Any, List + + +from http_utils import get +from cmd_utils import Sh, init_log + + +DOCKER_API_BASE = "https://hub.docker.com/v2/" +PAGE_SIZE = 25 +TEST_DATA = None + + +def docker_api(url: str) -> Dict[str, Any]: + """ + Run a paginated fetch from the public Docker Hub API + """ + if TEST_DATA is not None: + return TEST_DATA[url] + pagination = f"?page_size={PAGE_SIZE}&page=1" + url = DOCKER_API_BASE + url + pagination + r, headers = get(url) + reset = headers.get("x-ratelimit-reset") + if reset is not None: + reset = datetime.datetime.fromtimestamp(int(reset)) + reset = reset.isoformat() + logging.info( + f"Docker API Rate Limit: {headers.get('x-ratelimit-remaining')} / {headers.get('x-ratelimit-limit')} (reset at {reset})" + ) + if "results" not in r: + raise RuntimeError(f"Error fetching data, no results found in: {r}") + return r + + +def any_docker_changes_since(hash: str) -> bool: + """ + Check the docker/ directory, return True if there have been any code changes + since the specified hash + """ + sh = Sh() + cmd = f"git diff {hash} -- docker/" + proc = sh.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout = proc.stdout.strip() + return stdout != "", stdout + + +def does_commit_exist(hash: str) -> bool: + """ + Returns True if the hash exists in the repo + """ + sh = Sh() + cmd = f"git rev-parse -q {hash}" + proc = sh.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False) + print(proc.stdout) + if proc.returncode == 0: + return True + + if "unknown revision or path not in the working tree" in proc.stdout: + return False + + raise RuntimeError(f"Unexpected failure when running: {cmd}") + + +def find_hash_for_tag(tag: Dict[str, Any]) -> str: + """ + Split the hash off of a name like -