From 7a64997d5dafce75ec9875b601b4c648e45f8cc8 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 21 Jun 2022 06:39:32 -0700 Subject: [PATCH 01/27] optimize_torch & as_torch --- apps/pt_tvmdsoop/tests/test_tuning_relay.py | 55 +++++ .../tests/test_tvmscript_torch_module.py | 124 ++++++++++ gallery/how_to/work_with_pytorch/as_torch.py | 130 +++++++++++ .../work_with_pytorch/optimize_torch.py | 210 +++++++++++++++++ python/tvm/contrib/torch/__init__.py | 10 +- python/tvm/contrib/torch/script_torch.py | 197 ++++++++++++++++ src/contrib/torch/base64.h | 76 +++++++ .../pt_call_tvm/ExecutorFactoryWrapper.cc | 214 ++++++++++++++++++ 8 files changed, 1014 insertions(+), 2 deletions(-) create mode 100644 apps/pt_tvmdsoop/tests/test_tuning_relay.py create mode 100644 apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py create mode 100644 gallery/how_to/work_with_pytorch/as_torch.py create mode 100644 gallery/how_to/work_with_pytorch/optimize_torch.py create mode 100644 python/tvm/contrib/torch/script_torch.py create mode 100644 src/contrib/torch/base64.h create mode 100644 src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc diff --git a/apps/pt_tvmdsoop/tests/test_tuning_relay.py b/apps/pt_tvmdsoop/tests/test_tuning_relay.py new file mode 100644 index 000000000000..206e75400545 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_tuning_relay.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for tvm torch module""" +import tvm +import torch +from tvm.contrib.torch import optimize_torch +from tvm.meta_schedule import TuneConfig +import tvm.testing + + + +def matmul(x, w): + return torch.matmul(x, w) + + +def test_matmul_tuning_relay(): + config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=4, + max_trials_per_task=4, + max_trials_global=4, + search_strategy_config={ + "genetic_num_iters": 10, + }, + ) + x = torch.randn(15, 20) + w = torch.randn(20, 30) + example_inputs = (x, w) + + rt_mod = optimize_torch(matmul, example_inputs, config) + + torch_answer = torch.matmul(x, w).numpy() + tvm_answer = rt_mod(x, w).numpy() + + tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5) + +if __name__ == "__main__": + test_matmul_tuning_relay() + \ No newline at end of file diff --git a/apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py b/apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py new file mode 100644 index 000000000000..4cda9be4fe00 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for tvm torch module""" +import tvm +import torch +from tvm.contrib.torch import as_torch +from tvm.script import tir as T +import numpy as np +import torch.nn +import tvm.testing + + +@as_torch +def matmul(M: int, N: int, K: int, dtype: str): + @T.prim_func + def f(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + return f + + +@as_torch +@tvm.script.ir_module +class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. + vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 + + +class MinuesOnes(torch.nn.Module): + def __init__(self): + super(MinuesOnes, self).__init__() + self.engine = MyModule + + def forward(self, *input): + self.engine.forward(*input) + return input[-1] - 1 + + +def test_tvmscript_torch_matmul(): + s1 = np.ones((128, 128)).astype("float32") + s2 = np.ones((128, 128)).astype("float32") + s3 = np.zeros((128, 128)).astype("float32") + s1[0, 0] = 0 + s2[4, 4] = 0 + + q1 = torch.from_numpy(s1) + q2 = torch.from_numpy(s2) + q3 = torch.from_numpy(s3) + + numpy_result = np.matmul(s1, s2) + + tvm_module = matmul(128, 128, 128, "float32") + + tvm_module(q1, q2, q3) + + tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + + +def test_tvmscript_torch_decorator(): + s1 = np.arange(8).astype("float32") + + q1 = torch.arange(8).type(torch.float32) + q2 = torch.zeros((8,), dtype=torch.float32) + + numpy_result = s1 + 1 + + tvm_module = MyModule + + tvm_module(q1, q2) + + tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + + +def test_torch_with_tvmscirpt(): + s1 = np.arange(8).astype("float32") + + q1 = torch.arange(8).type(torch.float32) + q2 = torch.zeros((8,), dtype=torch.float32) + + tvm_module = MinuesOnes() + + ret = tvm_module.forward(q1, q2) + + tvm.testing.assert_allclose(ret.numpy(), s1, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + test_tvmscript_torch_matmul() + test_tvmscript_torch_decorator() + test_torch_with_tvmscirpt() diff --git a/gallery/how_to/work_with_pytorch/as_torch.py b/gallery/how_to/work_with_pytorch/as_torch.py new file mode 100644 index 000000000000..0d88ada429a6 --- /dev/null +++ b/gallery/how_to/work_with_pytorch/as_torch.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Wrap Your Tensor IR with PyTorch Module +====================== +**Author**: `Yaoda Zhou `_ +This article is an introductory tutorial to wrap the Tensor IR code with PyTorch module. +By the decorator `as_torch`, users are able to import a Tensor IR code in PyTorch with a low cost. +For us to follow this tutorial, PyTorch as well as TorchVision should be installed. +For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as +.. code-block:: bash + conda install -c conda-forge pytorch-gpu +""" +# Import Tvm and PyTorch, as well as necessary libraries +import tvm +import torch +from tvm.contrib.torch import as_torch +from tvm.script import tir as T +import numpy as np +import torch.nn +import tvm.testing + + +###################################################################### +# Define an example of vector add +# (This example could be found at https://tvm.apache.org/docs/tutorial/tensor_ir_blitz_course.html) +# ------------------------------- +# Our `as_torch` is a simple decorator: put it on any Tensor IR code and it will convert it into PyTorch module automatically. +@as_torch +@tvm.script.ir_module +class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + with T.block("B"): + vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 + +###################################################################### +# Write a test case: Tvm's testing is used to compare two tensors +# ------------------------------- + + +def test_tvmscript_torch_decorator(): + s1 = np.arange(8).astype("float32") + + # Define two torch tensors + q1 = torch.arange(8).type(torch.float32) + q2 = torch.zeros((8,), dtype=torch.float32) + + # Result from numpy + numpy_result = s1 + 1 + + tvm_module = MyModule + + # We call `MyModule` as PyTorch module's forward + tvm_module(q1, q2) + + # Testing. No output implies that tensors are equal + tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + + +test_tvmscript_torch_decorator() + +###################################################################### +# Another example: matrix multiplication with a limit form meta-programming +# ------------------------------- +# As above, we can add `as_torch` decorator to a Tensor IR function. 
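+# Here `as_torch` wraps a Python function that returns a PrimFunc, so the shape
+# parameters (M, N, K) and the dtype act as template arguments: for example,
+# calling `matmul(128, 128, 128, "float32")` below instantiates a concrete module.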
+ + +@as_torch +def matmul(M: int, N: int, K: int, dtype: str): + @T.prim_func + def f(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + return f + +###################################################################### +# Test case for `matmul` function. +# ------------------------------- + + +def test_tvmscript_torch_matmul(): + # Create two 128 x 128 matrixs as input + s1 = np.random.rand(128, 128).astype("float32") + s2 = np.random.rand(128, 128).astype("float32") + s3 = np.zeros((128, 128)).astype("float32") + + q1 = torch.from_numpy(s1) + q2 = torch.from_numpy(s2) + q3 = torch.from_numpy(s3) + + # Result from numpy + numpy_result = np.matmul(s1, np.transpose(s2)) + + # Instantiate the `matmul` function by passing the parameters of shapes and datatype + tvm_module = matmul(128, 128, 128, "float32") + + # Run the operator + tvm_module(q1, q2, q3) + + tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + + +test_tvmscript_torch_matmul() diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py new file mode 100644 index 000000000000..7e8208b96ea6 --- /dev/null +++ b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile PyTorch Models +====================== +**Author**: `Yaoda Zhou `_ +This article is an introductory tutorial to optimize PyTorch models by MetaSchedule. +For us to follow this tutorial, PyTorch as well as TorchVision should be installed. +For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as +.. 
code-block:: bash + conda install -c conda-forge pytorch-gpu +""" +# Import TVM and PyTorch +import tvm +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import resnet18 +import tvm.testing + +# Import `optimize_torch` function +from tvm.contrib.torch import optimize_torch +from tvm.meta_schedule import TuneConfig + +# Import library for profiling +import torch.utils.benchmark as benchmark + + +###################################################################### +# Define a simple module written by PyTorch +# ------------------------------ + + +class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + +# ##################################################################### +# Optimized SimpleModel by TorchScript +# ------------------------------ +# We trace this module, then run the `optimize_for_inference` function from TorchScript +simple_model = SimpleModel() +example_input = torch.randn(20, 1, 10, 10) +model_optimized_by_jit = torch.jit.optimize_for_inference( + torch.jit.trace(simple_model, example_input)) + +###################################################################### +# Optimized SimpleModel by TVM MetaSchedule +# ------------------------------ +# We provide a `optimize_torch` function, which have a similar usage as `torch.jit.trace`. +# For the function, we have five parameters need to provide. +# If the third parameter `tuning_config` is not provided, a default configuration is loaded. +# If the parameter `target` is empty, the model will deploy on CPU. + +tuning_config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=64, + max_trials_per_task=2000, + max_trials_global=2000, +) + +# We use default configuration for the first example +model_optimized_by_meta = optimize_torch( + SimpleModel(), example_input) + +###################################################################### +# Compare the performance between two scheduling approaches. +# ------------------------------ +# Using PyTorch's benchmark Compare class, we can have a straightforward comparison between two inference models. + +results = [] +for i in range(20): + test_input = torch.rand(20, 1, 10, 10) + sub_label = f'[test {i}]' + results.append(benchmark.Timer( + stmt='model_optimized_by_meta(test_input)', + setup='from __main__ import model_optimized_by_meta', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by meta', + ).blocked_autorange()) + results.append(benchmark.Timer( + stmt='model_optimized_by_jit(test_input)', + setup='from __main__ import model_optimized_by_jit', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by jit', + ).blocked_autorange()) + +# We can print the results on screen. +compare = benchmark.Compare(results) +compare.print() + +###################################################################### +# Save/Load module +# ------------------------------ +# We can save and load our tuned module like standard nn.module + +# Let us run our tuned module and see the result +ret1 = model_optimized_by_meta(example_input) + +torch.save(model_optimized_by_meta, "meta_model.pt") +model_loaded = torch.load("meta_model.pt") + +# We load the module and run again and it will return the same result as above. 
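+# (Saving and loading go through the TorchScript custom-class pickle hooks
+# registered by the C++ wrapper, so `torch.save`/`torch.load` need no
+# TVM-specific loader.)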
+ret2 = model_loaded(example_input) + +tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) + +###################################################################### +# Define the resnet18 optimized by MetaSchedule +# ------------------------------ +# Another example, we compare the two optimizers about the performance of resnet18 +# For learning how to define a resnet18 model via PyTorch's nn.Module, +# you can refer to https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting + +# In our working machine, the GPU model is nvidia/geforce-rtx-3070. +target_cuda = "nvidia/geforce-rtx-3070" + +# For PyTorch users, you can write your nn.Module in a normal way. +# By applying "optimize_torch" function on the resnet18 model, we obtain a new resnet18 model optimized by MetaSchedule + + +class MyResNet18(torch.nn.Module): + def __init__(self, config, target=None): + super(MyResNet18, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = optimize_torch( + resnet18(), [torch.rand(1, 3, 224, 224)], config, target) + + def forward(self, input): + return self.resnet(input - self.means) + + +# Since the setting of the number of trials is large, the initialization could be slow (sometimes more than 3 hours!) +meta_module_resnet18 = MyResNet18(tuning_config, target_cuda) + + +###################################################################### +# Define the resnet18 optimized by TorchScript +# ------------------------------ +# Besides, let us define a resnet18 model in a standard way. +# TorchScript also provide a built-in "optimize_for_inference" function to accelerate the inference. + +class JitModule(torch.nn.Module): + def __init__(self): + super(JitModule, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = torch.jit.optimize_for_inference( + torch.jit.script(resnet18().cuda().eval())) + + def forward(self, input): + return self.resnet(input - self.means) + + +jit_module_resnet18 = JitModule() + +###################################################################### +# Compare the performance between two scheduling approaches. +# ------------------------------ +# Using PyTorch's benchmark Compare class, we can have a straightforward comparison between two inference models. + +results = [] +for i in range(20): + test_input = torch.rand(1, 3, 224, 224).half().cuda() + sub_label = f'[test {i}]' + results.append(benchmark.Timer( + stmt='meta_module_resnet18(test_input)', + setup='from __main__ import meta_module_resnet18', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by meta', + ).blocked_autorange()) + results.append(benchmark.Timer( + stmt='jit_module_resnet18(test_input)', + setup='from __main__ import jit_module_resnet18', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by jit', + ).blocked_autorange()) + +# We can print the results on screen. 
+compare = benchmark.Compare(results) +compare.print() + +# As above, we can save the module for future use +torch.save(meta_module_resnet18, "meta_tuned_resnet18.pt") diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 720ac29cc6e2..ce84ecd9a1ad 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -20,8 +20,6 @@ import platform import torch from tvm._ffi import libinfo -from tvm.relay.frontend import pytorch - def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): system = platform.system() @@ -39,6 +37,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): _load_platform_specific_library() + from . import module GraphModule = module.GraphModule @@ -49,3 +48,10 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule compile = pytorch_tvm.compile + +from . import script_torch + +as_torch = script_torch.as_torch +GraphExecutorFactoryWrapper = script_torch.GraphExecutorFactoryWrapper +TVMScriptIRModule = script_torch.TVMScriptIRModule +optimize_torch = script_torch.optimize_torch \ No newline at end of file diff --git a/python/tvm/contrib/torch/script_torch.py b/python/tvm/contrib/torch/script_torch.py new file mode 100644 index 000000000000..c1e811c802c6 --- /dev/null +++ b/python/tvm/contrib/torch/script_torch.py @@ -0,0 +1,197 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import base64 +import functools +import tempfile +from typing import Callable, Dict, Tuple, Union, List + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig + + +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__( + self, + module: tvm.runtime.Module + ): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + else: + return ret + + +class TVMScriptIRModule(torch.nn.Module): + def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): + super().__init__() + self.engine_cpu = None + self.engine_cuda = None + self.ir_module = module + + def __save_cpu_rt_module(self, runtime_module): + func = tvm.get_global_func("tvmtorch.save_runtime_mod") + func(runtime_module) + + self.engine_cpu = torch.classes.tvm_torch.TVMScriptRuntime() + + def __build_cpu(self): + runtime_module = tvm.build(self.ir_module) + self.__save_cpu_rt_module(runtime_module) + + def __save_cuda_rt_module(self, runtime_module): + self.engine_cuda = runtime_module + + def __build_cuda(self): + runtime_module = tvm.build(self.ir_module, target=tvm.target.cuda()) + self.__save_cuda_rt_module(runtime_module) + + def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: + if torch_inputs[0].is_cuda: + if self.engine_cuda is None: + self.__build_cuda() + return self.engine_cuda.forward(torch_inputs) + else: + if self.engine_cpu is None: + self.__build_cpu() + return self.engine_cpu.forward(torch_inputs) + + +def as_torch( + func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] +): + """A decorator of converting TensorIR to PyTorch nn.Module. + + Parameters + ---------- + func : Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] + The function to be parsed. + + + Returns + ------- + mod : TVMScriptIRModule + It will return an object of TVMScriptIRModule, which is the subclass of the original nn.Module. + """ + if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): + return TVMScriptIRModule(func) + elif isinstance(func, Callable): + def func_get_param(*args, **kargs): + return TVMScriptIRModule(func(*args, **kargs)) + return func_get_param + + +@functools.lru_cache(None) +def llvm_target(): + return "llvm -num-cores" + + +def tuning_relay(mod: tvm.ir.module.IRModule, params: Dict, config: TuneConfig, target, work_dir: str = None): + from tvm.meta_schedule.tune import tune_relay + with tempfile.TemporaryDirectory() as tmp_work_dir: + return tune_relay( + mod=mod, + params=params, + target=target, + config=config, + work_dir=work_dir if work_dir else tmp_work_dir, + ) + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. + + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. 
(ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. + + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. + + target : Optional[Union[str, Target]] + The target of the compilation. + If user doesn't set the target, the module is built upon the LLVM. + + work_dir : Optional[str] + The working directory to save intermediate results. + + Returns + ------- + mod : GraphExecutorFactoryWrapper + It will return an object of GraphExecutorFactoryWrapper, which is the subclass of the original nn.Module. + """ + + if target: + pass + else: + target = llvm_target() + + if tuning_config: + pass + else: + # Default setting. For a better tuning result the number could be set larger. + tuning_config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=4, + max_trials_per_task=16, + max_trials_global=16, + ) + + jit_mod = torch.jit.trace(func, example_inputs) + if isinstance(example_inputs, torch.Tensor): + example_inputs = [example_inputs] + shape_list = [(f"inp_{idx}", i.shape) + for idx, i in enumerate(example_inputs)] + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule + executor_factory = tuning_relay( + mod, params, tuning_config, target, work_dir) + + save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") + save_runtime_mod(executor_factory.module) + + return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.RelayRuntime()) diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h new file mode 100644 index 000000000000..d949db572c52 --- /dev/null +++ b/src/contrib/torch/base64.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file base64.h + * \brief Util functions for converting plain bytes back to plain bytes + */ + +#ifndef TVM_SUPPORT_BASE64_RT_H_ +#define TVM_SUPPORT_BASE64_RT_H_ + +#include + +#include +#include +#include + +#include "../../support/base64.h" + +namespace tvm { +namespace support { + +size_t b64strlen(std::string& b64str) { + ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; + size_t length = b64str.size() / 4 * 3; + if (b64str[b64str.size() - 2] == '=') { + length -= 2; + } else if (b64str[b64str.size() - 1] == '=') { + length -= 1; + } + return length; +} + +void b64decode(std::string& b64str, u_char* ret) { + size_t index = 0; + const auto length = b64str.size(); + for (size_t i = 0; i < length; i += 4) { + int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]]; + int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]]; + int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]]; + int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]]; + u_char st1 = (ch0 << 2) + (ch1 >> 4); + ret[index++] = st1; + if (b64str[i + 2] != '=') { + u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2); + ret[index++] = st2; + if (b64str[i + 3] != '=') { + u_char st3 = ((ch2 & 0b11) << 6) + ch3; + ret[index++] = st3; + } + } + } + ret[index] = '\0'; + ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; +} + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_BASE64_RT_H_ \ No newline at end of file diff --git a/src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc b/src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc new file mode 100644 index 000000000000..9adc771c1840 --- /dev/null +++ b/src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../base64.h" + +namespace tvm { +namespace contrib { + +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +class TVMScriptRuntimeClass : public torch::jit::CustomClassHolder { + public: + TVMScriptRuntimeClass() { mod_ = ThreadLocalStore::ThreadLocal()->mod; } + + void forward(const c10::List& inputs) { + int input_length = inputs.size(); + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + tvm::runtime::PackedFunc run = mod_.GetFunction("__tvm_main__"); + + std::vector tvm_values(input_length); + std::vector tvm_type_codes(input_length); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + for (int k = 0; k < input_length; ++k) { + setter(k, &tensors[k]->dl_tensor); + } + + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length), + nullptr); + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + } + + private: + tvm::runtime::Module mod_; +}; + +tvm::Device getDevice(const at::Tensor& tensor) { + tvm::Device dev; + dev.device_id = tensor.get_device(); + switch (tensor.device().type()) { + case at::DeviceType::CPU: + dev.device_type = DLDeviceType::kDLCPU; + if (dev.device_id == -1) { + dev.device_id = 1; + } + break; + case at::DeviceType::CUDA: + dev.device_type = DLDeviceType::kDLCUDA; + break; + default: + TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str()); + } + return dev; +} + +class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { + public: + GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) + : executor_factory_(executor_factory) {} + + GraphExecutorFactoryWrapper() + : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {} + + c10::List forward(const c10::List& inputs) { + int input_length = inputs.size(); + + if (!executor_.defined()) { + TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); + DLDevice input_device = getDevice(inputs.get(0)); + + auto tmp = executor_factory_.GetFunction("default"); + + executor_ = tmp(input_device); + } + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + tvm::runtime::PackedFunc run = executor_.GetFunction("run"); + tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input"); + tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output"); + tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs"); + + for (int k = 0; k < input_length; ++k) { + set_input(k, &tensors[k]->dl_tensor); + } + + run(); + + int64_t output_length = get_num_outputs(); + + c10::List outputs; + outputs.reserve(output_length); + + for (int k = 0; k < output_length; ++k) { + tvm::runtime::NDArray results = get_output(k); + at::Tensor atTensor = at::fromDLPack(results.ToDLPack()); + outputs.emplace_back(atTensor); + } + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + return outputs; + } + + using SerializationType = std::string; // executor factory stream + + SerializationType Serialize() { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << 
"IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(executor_factory_); + } + + GraphExecutorFactoryWrapper(SerializationType state) { + auto length = tvm::support::b64strlen(state); + + u_char bytes[length]; + memset(bytes, 0, sizeof(bytes)); + tvm::support::b64decode(state, bytes); + + const char* name = tmpnam(NULL); + auto file_name = std::string(name) + ".so"; + auto pFile = fopen(file_name.c_str(), "wb"); + fwrite(bytes, sizeof(u_char), length, pFile); + fclose(pFile); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + executor_factory_ = (*f)(file_name, ""); + + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + } + + private: + tvm::runtime::Module executor_factory_; + tvm::runtime::Module executor_; +}; + +TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { + ThreadLocalStore::ThreadLocal()->mod = mod; +}); + +TORCH_LIBRARY(tvm_torch, m) { + m.class_("TVMScriptRuntime") + .def(torch::init<>()) + .def("forward", &TVMScriptRuntimeClass::forward); +} + +TORCH_LIBRARY(tvm_tuning, m) { + m.class_("RelayRuntime") + .def(torch::init<>()) + .def("forward", &GraphExecutorFactoryWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) + -> GraphExecutorFactoryWrapper::SerializationType { return self->Serialize(); }, + [](GraphExecutorFactoryWrapper::SerializationType state) { + return c10::make_intrusive(state); + }); +} + +} // namespace contrib +} // namespace tvm \ No newline at end of file From d12eb54cd44058eed29cc490856052b7f3be1112 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 22 Jun 2022 01:01:13 -0700 Subject: [PATCH 02/27] split files --- ...cript_torch_module.py => test_as_torch.py} | 0 apps/pt_tvmdsoop/tests/test_optimize_torch.py | 148 ++++++++++++++++++ apps/pt_tvmdsoop/tests/test_tuning_relay.py | 55 ------- .../work_with_pytorch/optimize_torch.py | 123 +++------------ python/tvm/contrib/torch/__init__.py | 13 +- python/tvm/contrib/torch/as_torch.py | 84 ++++++++++ .../{script_torch.py => optimize_torch.py} | 93 ++--------- 7 files changed, 275 insertions(+), 241 deletions(-) rename apps/pt_tvmdsoop/tests/{test_tvmscript_torch_module.py => test_as_torch.py} (100%) create mode 100644 apps/pt_tvmdsoop/tests/test_optimize_torch.py delete mode 100644 apps/pt_tvmdsoop/tests/test_tuning_relay.py create mode 100644 python/tvm/contrib/torch/as_torch.py rename python/tvm/contrib/torch/{script_torch.py => optimize_torch.py} (57%) diff --git a/apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py b/apps/pt_tvmdsoop/tests/test_as_torch.py similarity index 100% rename from apps/pt_tvmdsoop/tests/test_tvmscript_torch_module.py rename to apps/pt_tvmdsoop/tests/test_as_torch.py diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py new file mode 100644 index 000000000000..db58cd25c926 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -0,0 +1,148 @@ +# pylint: disable=missing-class-docstring +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for tvm torch module""" +import tempfile + +import torch +import torch.utils.benchmark as benchmark +from torchvision.models import resnet18 + +import tvm +import tvm.testing +from tvm.contrib.torch import optimize_torch +from tvm.meta_schedule import TuneConfig + +# default config for testing +config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=64, + max_trials_per_task=2000, + max_trials_global=2000, + ) + +def test_matmul_tuning_relay(): + def matmul(x, w): + return torch.matmul(x, w) + + x = torch.randn(15, 20) + w = torch.randn(20, 30) + example_inputs = (x, w) + + rt_mod = optimize_torch(matmul, example_inputs, config) + + torch_answer = torch.matmul(x, w).numpy() + tvm_answer = rt_mod(x, w).numpy() + + tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5) + +class InnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 20, 5) + + def forward(self, x): + return torch.nn.functional.relu(self.conv(x)) + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(20, 20, 5) + self.relu = InnerModel() + + def forward(self, x): + x = self.relu(x) + return torch.nn.functional.relu(self.conv(x)) + +def test_nested_module(): + simple_module = SimpleModel() + example_input = torch.randn(20, 1, 10, 10) + optimized_module = optimize_torch( + simple_module, example_input) + ret1 = simple_module(example_input).detach().numpy() + ret2 = optimized_module(example_input).detach().numpy() + tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) + +def test_save_load_function(): + def foo(x): + return 2 * x + 1 + + example_input = torch.rand(3) + opt_foo = optimize_torch(foo, example_input) + ret1 = opt_foo(example_input) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + torch.save(opt_foo, tmp.name) + loaded_mod = torch.load(tmp.name) + ret2 = loaded_mod(example_input) + tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) + +class MyResNet18(torch.nn.Module): + def __init__(self, config, target=None): + super(MyResNet18, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = optimize_torch( + resnet18(), [torch.rand(1, 3, 224, 224)], config, target) + + def forward(self, input): + return self.resnet(input - self.means) + +class JitModule(torch.nn.Module): + def __init__(self): + super(JitModule, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = torch.jit.optimize_for_inference( + torch.jit.script(resnet18().cuda().eval())) + + def forward(self, input): + return self.resnet(input - self.means) + +target_cuda = "nvidia/geforce-rtx-3070" 
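+# Constructing MyResNet18 runs the MetaSchedule search with the trial budget in
+# `config`, so this step may take a long time on the target GPU.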
+meta_module_resnet18 = MyResNet18(config, target_cuda) +jit_module_resnet18 = JitModule() + +def compare_optimize_resnet18_to_torchscript(): + results = [] + for i in range(20): + test_input = torch.rand(1, 3, 224, 224).half().cuda() + sub_label = f'[test {i}]' + results.append(benchmark.Timer( + stmt='meta_module_resnet18(test_input)', + setup='from __main__ import meta_module_resnet18', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by meta', + ).blocked_autorange()) + results.append(benchmark.Timer( + stmt='jit_module_resnet18(test_input)', + setup='from __main__ import jit_module_resnet18', + globals={'test_input': test_input}, + sub_label=sub_label, + description='tuning by jit', + ).blocked_autorange()) + compare = benchmark.Compare(results) + compare.print() + +if __name__ == "__main__": + test_matmul_tuning_relay() + test_nested_module() + test_save_load_function() + compare_optimize_resnet18_to_torchscript() + \ No newline at end of file diff --git a/apps/pt_tvmdsoop/tests/test_tuning_relay.py b/apps/pt_tvmdsoop/tests/test_tuning_relay.py deleted file mode 100644 index 206e75400545..000000000000 --- a/apps/pt_tvmdsoop/tests/test_tuning_relay.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test script for tvm torch module""" -import tvm -import torch -from tvm.contrib.torch import optimize_torch -from tvm.meta_schedule import TuneConfig -import tvm.testing - - - -def matmul(x, w): - return torch.matmul(x, w) - - -def test_matmul_tuning_relay(): - config = TuneConfig( - strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=4, - max_trials_global=4, - search_strategy_config={ - "genetic_num_iters": 10, - }, - ) - x = torch.randn(15, 20) - w = torch.randn(20, 30) - example_inputs = (x, w) - - rt_mod = optimize_torch(matmul, example_inputs, config) - - torch_answer = torch.matmul(x, w).numpy() - tvm_answer = rt_mod(x, w).numpy() - - tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5) - -if __name__ == "__main__": - test_matmul_tuning_relay() - \ No newline at end of file diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py index 7e8208b96ea6..0d759668de2b 100644 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -24,18 +24,17 @@ .. 
code-block:: bash conda install -c conda-forge pytorch-gpu """ -# Import TVM and PyTorch +# Import TVM import tvm -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision.models import resnet18 import tvm.testing - # Import `optimize_torch` function from tvm.contrib.torch import optimize_torch from tvm.meta_schedule import TuneConfig +# Import PyTorch +import torch +import torch.nn as nn +import torch.nn.functional as F # Import library for profiling import torch.utils.benchmark as benchmark @@ -44,7 +43,6 @@ # Define a simple module written by PyTorch # ------------------------------ - class SimpleModel(nn.Module): def __init__(self): super().__init__() @@ -55,16 +53,6 @@ def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) - -# ##################################################################### -# Optimized SimpleModel by TorchScript -# ------------------------------ -# We trace this module, then run the `optimize_for_inference` function from TorchScript -simple_model = SimpleModel() -example_input = torch.randn(20, 1, 10, 10) -model_optimized_by_jit = torch.jit.optimize_for_inference( - torch.jit.trace(simple_model, example_input)) - ###################################################################### # Optimized SimpleModel by TVM MetaSchedule # ------------------------------ @@ -73,16 +61,28 @@ def forward(self, x): # If the third parameter `tuning_config` is not provided, a default configuration is loaded. # If the parameter `target` is empty, the model will deploy on CPU. + +example_input = torch.randn(20, 1, 10, 10) + tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=2000, - max_trials_global=2000, + num_trials_per_iter=2, + max_trials_per_task=2, + max_trials_global=2, ) + # We use default configuration for the first example model_optimized_by_meta = optimize_torch( - SimpleModel(), example_input) + SimpleModel(), example_input, tuning_config) + +# ##################################################################### +# Optimized SimpleModel by TorchScript +# ------------------------------ +# As a comparison, we trace this module via `optimize_for_inference` function from TorchScript +model_optimized_by_jit = torch.jit.optimize_for_inference( + torch.jit.trace(SimpleModel(), example_input)) + ###################################################################### # Compare the performance between two scheduling approaches. @@ -127,84 +127,3 @@ def forward(self, x): ret2 = model_loaded(example_input) tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) - -###################################################################### -# Define the resnet18 optimized by MetaSchedule -# ------------------------------ -# Another example, we compare the two optimizers about the performance of resnet18 -# For learning how to define a resnet18 model via PyTorch's nn.Module, -# you can refer to https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting - -# In our working machine, the GPU model is nvidia/geforce-rtx-3070. -target_cuda = "nvidia/geforce-rtx-3070" - -# For PyTorch users, you can write your nn.Module in a normal way. 
-# By applying "optimize_torch" function on the resnet18 model, we obtain a new resnet18 model optimized by MetaSchedule - - -class MyResNet18(torch.nn.Module): - def __init__(self, config, target=None): - super(MyResNet18, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = optimize_torch( - resnet18(), [torch.rand(1, 3, 224, 224)], config, target) - - def forward(self, input): - return self.resnet(input - self.means) - - -# Since the setting of the number of trials is large, the initialization could be slow (sometimes more than 3 hours!) -meta_module_resnet18 = MyResNet18(tuning_config, target_cuda) - - -###################################################################### -# Define the resnet18 optimized by TorchScript -# ------------------------------ -# Besides, let us define a resnet18 model in a standard way. -# TorchScript also provide a built-in "optimize_for_inference" function to accelerate the inference. - -class JitModule(torch.nn.Module): - def __init__(self): - super(JitModule, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = torch.jit.optimize_for_inference( - torch.jit.script(resnet18().cuda().eval())) - - def forward(self, input): - return self.resnet(input - self.means) - - -jit_module_resnet18 = JitModule() - -###################################################################### -# Compare the performance between two scheduling approaches. -# ------------------------------ -# Using PyTorch's benchmark Compare class, we can have a straightforward comparison between two inference models. - -results = [] -for i in range(20): - test_input = torch.rand(1, 3, 224, 224).half().cuda() - sub_label = f'[test {i}]' - results.append(benchmark.Timer( - stmt='meta_module_resnet18(test_input)', - setup='from __main__ import meta_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by meta', - ).blocked_autorange()) - results.append(benchmark.Timer( - stmt='jit_module_resnet18(test_input)', - setup='from __main__ import jit_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by jit', - ).blocked_autorange()) - -# We can print the results on screen. -compare = benchmark.Compare(results) -compare.print() - -# As above, we can save the module for future use -torch.save(meta_module_resnet18, "meta_tuned_resnet18.pt") diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index ce84ecd9a1ad..173aa2491291 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -49,9 +49,12 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule compile = pytorch_tvm.compile -from . import script_torch +from . import as_torch -as_torch = script_torch.as_torch -GraphExecutorFactoryWrapper = script_torch.GraphExecutorFactoryWrapper -TVMScriptIRModule = script_torch.TVMScriptIRModule -optimize_torch = script_torch.optimize_torch \ No newline at end of file +TVMScriptIRModule = as_torch.TVMScriptIRModule +as_torch = as_torch.as_torch + +from . 
import optimize_torch + +GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper +optimize_torch = optimize_torch.optimize_torch \ No newline at end of file diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py new file mode 100644 index 000000000000..7e3c155e7cd6 --- /dev/null +++ b/python/tvm/contrib/torch/as_torch.py @@ -0,0 +1,84 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable, List, Tuple, Union + +import torch +import torch.utils.dlpack + +import tvm + + +class TVMScriptIRModule(torch.nn.Module): + def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): + super().__init__() + self.engine_cpu = None + self.engine_cuda = None + self.ir_module = module + + def __save_cpu_rt_module(self, runtime_module): + func = tvm.get_global_func("tvmtorch.save_runtime_mod") + func(runtime_module) + + self.engine_cpu = torch.classes.tvm_torch.TVMScriptRuntime() + + def build_cpu(self): + runtime_module = tvm.build(self.ir_module) + self.__save_cpu_rt_module(runtime_module) + + def __save_cuda_rt_module(self, runtime_module): + self.engine_cuda = runtime_module + + def build_cuda(self): + runtime_module = tvm.build(self.ir_module, target=tvm.target.cuda()) + self.__save_cuda_rt_module(runtime_module) + + def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: + if torch_inputs[0].is_cuda: + if self.engine_cuda is None: + self.build_cuda() + return self.engine_cuda.forward(torch_inputs) + else: + if self.engine_cpu is None: + self.build_cpu() + return self.engine_cpu.forward(torch_inputs) + + +def as_torch( + func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] +): + """A decorator of converting TensorIR to PyTorch nn.Module. + + Parameters + ---------- + func : Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] + The function to be parsed. + + + Returns + ------- + mod : TVMScriptIRModule + It will return an object of TVMScriptIRModule, which is the subclass of the original nn.Module. 
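+
+    Examples
+    --------
+    A minimal usage sketch, mirroring `apps/pt_tvmdsoop/tests/test_as_torch.py`:
+
+    .. code-block:: python
+
+        @as_torch
+        @tvm.script.ir_module
+        class MyModule:
+            ...  # TVMScript definition of the kernel
+
+        MyModule(x, y)  # x, y: input/output torch tensors; runs like nn.Module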
+ """ + if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): + return TVMScriptIRModule(func) + elif isinstance(func, Callable): + def func_get_param(*args, **kargs): + return TVMScriptIRModule(func(*args, **kargs)) + return func_get_param diff --git a/python/tvm/contrib/torch/script_torch.py b/python/tvm/contrib/torch/optimize_torch.py similarity index 57% rename from python/tvm/contrib/torch/script_torch.py rename to python/tvm/contrib/torch/optimize_torch.py index c1e811c802c6..eb5693dbf40f 100644 --- a/python/tvm/contrib/torch/script_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -19,10 +19,10 @@ # under the License. import base64 -import functools +import contextlib import tempfile from typing import Callable, Dict, Tuple, Union, List - +from tvm.meta_schedule.tune import tune_relay import torch import torch.utils.dlpack @@ -48,82 +48,10 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]): return ret -class TVMScriptIRModule(torch.nn.Module): - def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): - super().__init__() - self.engine_cpu = None - self.engine_cuda = None - self.ir_module = module - - def __save_cpu_rt_module(self, runtime_module): - func = tvm.get_global_func("tvmtorch.save_runtime_mod") - func(runtime_module) - - self.engine_cpu = torch.classes.tvm_torch.TVMScriptRuntime() - - def __build_cpu(self): - runtime_module = tvm.build(self.ir_module) - self.__save_cpu_rt_module(runtime_module) - - def __save_cuda_rt_module(self, runtime_module): - self.engine_cuda = runtime_module - - def __build_cuda(self): - runtime_module = tvm.build(self.ir_module, target=tvm.target.cuda()) - self.__save_cuda_rt_module(runtime_module) - - def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: - if torch_inputs[0].is_cuda: - if self.engine_cuda is None: - self.__build_cuda() - return self.engine_cuda.forward(torch_inputs) - else: - if self.engine_cpu is None: - self.__build_cpu() - return self.engine_cpu.forward(torch_inputs) - - -def as_torch( - func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] -): - """A decorator of converting TensorIR to PyTorch nn.Module. - - Parameters - ---------- - func : Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] - The function to be parsed. - - - Returns - ------- - mod : TVMScriptIRModule - It will return an object of TVMScriptIRModule, which is the subclass of the original nn.Module. - """ - if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): - return TVMScriptIRModule(func) - elif isinstance(func, Callable): - def func_get_param(*args, **kargs): - return TVMScriptIRModule(func(*args, **kargs)) - return func_get_param - - -@functools.lru_cache(None) def llvm_target(): return "llvm -num-cores" -def tuning_relay(mod: tvm.ir.module.IRModule, params: Dict, config: TuneConfig, target, work_dir: str = None): - from tvm.meta_schedule.tune import tune_relay - with tempfile.TemporaryDirectory() as tmp_work_dir: - return tune_relay( - mod=mod, - params=params, - target=target, - config=config, - work_dir=work_dir if work_dir else tmp_work_dir, - ) - - @register_func("script_torch.save_to_base64") def save_to_base64(obj) -> bytes: with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: @@ -177,9 +105,9 @@ def optimize_torch( # Default setting. For a better tuning result the number could be set larger. 
tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=16, - max_trials_global=16, + num_trials_per_iter=1, + max_trials_per_task=2, + max_trials_global=0, ) jit_mod = torch.jit.trace(func, example_inputs) @@ -188,8 +116,15 @@ def optimize_torch( shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule - executor_factory = tuning_relay( - mod, params, tuning_config, target, work_dir) + if work_dir: + cm = contextlib.nullcontext() + else: + cm = tempfile.TemporaryDirectory() + with cm as work_dir_path: + if work_dir is None: + work_dir = work_dir_path + executor_factory = tune_relay( + mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir) save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") save_runtime_mod(executor_factory.module) From 9de55278fc9a62af10a8a0bf3c9bdf5118ca4f16 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 22 Jun 2022 02:43:43 -0700 Subject: [PATCH 03/27] code formatting --- apps/pt_tvmdsoop/tests/test_optimize_torch.py | 2 +- python/tvm/contrib/torch/__init__.py | 2 +- python/tvm/contrib/torch/as_torch.py | 27 ++++----- python/tvm/contrib/torch/optimize_torch.py | 11 ++-- ...toryWrapper.cc => RuntimeModuleWrapper.cc} | 55 ++++++++++++++++--- 5 files changed, 69 insertions(+), 28 deletions(-) rename src/contrib/torch/pt_call_tvm/{ExecutorFactoryWrapper.cc => RuntimeModuleWrapper.cc} (74%) diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py index db58cd25c926..b1f83e94174f 100644 --- a/apps/pt_tvmdsoop/tests/test_optimize_torch.py +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -45,7 +45,7 @@ def matmul(x, w): w = torch.randn(20, 30) example_inputs = (x, w) - rt_mod = optimize_torch(matmul, example_inputs, config) + rt_mod = optimize_torch(matmul, example_inputs) torch_answer = torch.matmul(x, w).numpy() tvm_answer = rt_mod(x, w).numpy() diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 173aa2491291..1911c8388239 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -51,7 +51,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): from . import as_torch -TVMScriptIRModule = as_torch.TVMScriptIRModule +TVMScriptIRModule = as_torch.OperatorModuleWrapper as_torch = as_torch.as_torch from . 
import optimize_torch diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 7e3c155e7cd6..99a2a7b023cb 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -25,29 +25,24 @@ import tvm -class TVMScriptIRModule(torch.nn.Module): +class OperatorModuleWrapper(torch.nn.Module): def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): super().__init__() self.engine_cpu = None self.engine_cuda = None self.ir_module = module - def __save_cpu_rt_module(self, runtime_module): - func = tvm.get_global_func("tvmtorch.save_runtime_mod") - func(runtime_module) - - self.engine_cpu = torch.classes.tvm_torch.TVMScriptRuntime() - def build_cpu(self): runtime_module = tvm.build(self.ir_module) - self.__save_cpu_rt_module(runtime_module) + func = tvm.get_global_func("tvmtorch.save_runtime_mod") + func(runtime_module) - def __save_cuda_rt_module(self, runtime_module): - self.engine_cuda = runtime_module + self.engine_cpu = torch.classes.tvm_torch.OperatorModuleWrapper() def build_cuda(self): + # If the module build on cuda, we won't call the C++ code since some information is missing runtime_module = tvm.build(self.ir_module, target=tvm.target.cuda()) - self.__save_cuda_rt_module(runtime_module) + self.engine_cuda = runtime_module def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: if torch_inputs[0].is_cuda: @@ -55,6 +50,8 @@ def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: self.build_cuda() return self.engine_cuda.forward(torch_inputs) else: + # We force the tensor inputs to be on cpu. + torch_inputs = tuple(map(lambda x: x.cpu(), torch_inputs)) if self.engine_cpu is None: self.build_cpu() return self.engine_cpu.forward(torch_inputs) @@ -73,12 +70,12 @@ def as_torch( Returns ------- - mod : TVMScriptIRModule - It will return an object of TVMScriptIRModule, which is the subclass of the original nn.Module. + mod : OperatorModuleWrapper + It will return an object of OperatorModuleWrapper, which is the subclass of the original nn.Module. """ if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): - return TVMScriptIRModule(func) + return OperatorModuleWrapper(func) elif isinstance(func, Callable): def func_get_param(*args, **kargs): - return TVMScriptIRModule(func(*args, **kargs)) + return OperatorModuleWrapper(func(*args, **kargs)) return func_get_param diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index eb5693dbf40f..0bb734016757 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -105,14 +105,17 @@ def optimize_torch( # Default setting. For a better tuning result the number could be set larger. 
tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=1, - max_trials_per_task=2, - max_trials_global=0, + num_trials_per_iter=4, + max_trials_per_task=16, + max_trials_global=16, ) + # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) + if isinstance(example_inputs, torch.Tensor): example_inputs = [example_inputs] + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule @@ -129,4 +132,4 @@ def optimize_torch( save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") save_runtime_mod(executor_factory.module) - return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.RelayRuntime()) + return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.GraphExecutorFactoryWrapper()) diff --git a/src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc similarity index 74% rename from src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc rename to src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 9adc771c1840..501fbbc7f48d 100644 --- a/src/contrib/torch/pt_call_tvm/ExecutorFactoryWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -44,9 +44,9 @@ struct ThreadLocalStore { } }; -class TVMScriptRuntimeClass : public torch::jit::CustomClassHolder { +class OperatorModuleWrapper : public torch::jit::CustomClassHolder { public: - TVMScriptRuntimeClass() { mod_ = ThreadLocalStore::ThreadLocal()->mod; } + OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } void forward(const c10::List& inputs) { int input_length = inputs.size(); @@ -55,7 +55,7 @@ class TVMScriptRuntimeClass : public torch::jit::CustomClassHolder { for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); - tvm::runtime::PackedFunc run = mod_.GetFunction("__tvm_main__"); + tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__"); std::vector tvm_values(input_length); std::vector tvm_type_codes(input_length); @@ -72,8 +72,44 @@ class TVMScriptRuntimeClass : public torch::jit::CustomClassHolder { } } + using SerializationType = std::string; // executor factory stream + + SerializationType Serialize() { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(runtime_module); + } + + OperatorModuleWrapper(SerializationType state) { + auto length = tvm::support::b64strlen(state); + + u_char bytes[length]; + memset(bytes, 0, sizeof(bytes)); + tvm::support::b64decode(state, bytes); + + const char* name = tmpnam(NULL); + auto file_name = std::string(name) + ".so"; + auto pFile = fopen(file_name.c_str(), "wb"); + fwrite(bytes, sizeof(u_char), length, pFile); + fclose(pFile); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." 
+ << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + runtime_module = (*f)(file_name, ""); + + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + } + private: - tvm::runtime::Module mod_; + tvm::runtime::Module runtime_module; }; tvm::Device getDevice(const at::Tensor& tensor) { @@ -193,13 +229,18 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: }); TORCH_LIBRARY(tvm_torch, m) { - m.class_("TVMScriptRuntime") + m.class_("OperatorModuleWrapper") .def(torch::init<>()) - .def("forward", &TVMScriptRuntimeClass::forward); + .def("forward", &OperatorModuleWrapper::forward) + .def_pickle([](const c10::intrusive_ptr& self) + -> OperatorModuleWrapper::SerializationType { return self->Serialize(); }, + [](OperatorModuleWrapper::SerializationType state) { + return c10::make_intrusive(state); + }); } TORCH_LIBRARY(tvm_tuning, m) { - m.class_("RelayRuntime") + m.class_("GraphExecutorFactoryWrapper") .def(torch::init<>()) .def("forward", &GraphExecutorFactoryWrapper::forward) .def_pickle( From e0703a6f9b4a1ffc56e00d58755e8163f6c6eecf Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 22 Jun 2022 20:36:27 -0700 Subject: [PATCH 04/27] optimizing optimized_torch --- apps/pt_tvmdsoop/tests/test_as_torch.py | 12 +- apps/pt_tvmdsoop/tests/test_optimize_torch.py | 18 +-- .../work_with_pytorch/optimize_torch.py | 108 +++++++++++++----- python/tvm/contrib/torch/optimize_torch.py | 12 +- 4 files changed, 97 insertions(+), 53 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 4cda9be4fe00..2b1185681493 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -83,9 +83,9 @@ def test_tvmscript_torch_matmul(): numpy_result = np.matmul(s1, s2) - tvm_module = matmul(128, 128, 128, "float32") + nn_module = matmul(128, 128, 128, "float32") - tvm_module(q1, q2, q3) + nn_module(q1, q2, q3) tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) @@ -98,9 +98,7 @@ def test_tvmscript_torch_decorator(): numpy_result = s1 + 1 - tvm_module = MyModule - - tvm_module(q1, q2) + MyModule(q1, q2) tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) @@ -111,9 +109,9 @@ def test_torch_with_tvmscirpt(): q1 = torch.arange(8).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32) - tvm_module = MinuesOnes() + nn_module = MinuesOnes() - ret = tvm_module.forward(q1, q2) + ret = nn_module.forward(q1, q2) tvm.testing.assert_allclose(ret.numpy(), s1, atol=1e-5, rtol=1e-5) diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py index b1f83e94174f..451fd80d552f 100644 --- a/apps/pt_tvmdsoop/tests/test_optimize_torch.py +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -32,9 +32,9 @@ # default config for testing config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=2000, - max_trials_global=2000, + num_trials_per_iter=2, + max_trials_per_task=4, + max_trials_global=0, ) def test_matmul_tuning_relay(): @@ -114,9 +114,10 @@ def __init__(self): def forward(self, input): return self.resnet(input - self.means) -target_cuda = "nvidia/geforce-rtx-3070" -meta_module_resnet18 = MyResNet18(config, target_cuda) -jit_module_resnet18 = JitModule() +if torch.cuda.is_available(): + target_cuda = "nvidia/geforce-rtx-3070" + meta_module_resnet18 = 
MyResNet18(config, target_cuda) + jit_module_resnet18 = JitModule() def compare_optimize_resnet18_to_torchscript(): results = [] @@ -141,8 +142,9 @@ def compare_optimize_resnet18_to_torchscript(): compare.print() if __name__ == "__main__": - test_matmul_tuning_relay() + test_matmul_tuning_relay() test_nested_module() test_save_load_function() - compare_optimize_resnet18_to_torchscript() + if torch.cuda.is_available(): + compare_optimize_resnet18_to_torchscript() \ No newline at end of file diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py index 0d759668de2b..a033893259be 100644 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -35,6 +35,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torchvision.models import resnet18 # Import library for profiling import torch.utils.benchmark as benchmark @@ -64,25 +65,83 @@ def forward(self, x): example_input = torch.randn(20, 1, 10, 10) +# We use default configuration for the first example +model_optimized_by_meta = optimize_torch( + SimpleModel(), example_input) + +###################################################################### +# Save/Load module +# ------------------------------ +# We can save and load our tuned module like standard nn.module + +# Let us run our tuned module and see the result +ret1 = model_optimized_by_meta(example_input) + +torch.save(model_optimized_by_meta, "meta_model.pt") +model_loaded = torch.load("meta_model.pt") + +# We load the module and run again and it will return the same result as above. +ret2 = model_loaded(example_input) + +tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) + +###################################################################### +# Define the resnet18 optimized by MetaSchedule +# ------------------------------ +# Another example, we compare the two optimizers about the performance of resnet18 +# For learning how to define a resnet18 model via PyTorch's nn.Module, +# you can refer to https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting + +# In our working machine, the GPU model is nvidia/geforce-rtx-3070. +target_cuda = "nvidia/geforce-rtx-3070" + +# We can define the configuration by ourselves tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=2, - max_trials_per_task=2, - max_trials_global=2, + num_trials_per_iter=1, + max_trials_per_task=1, + max_trials_global=0, ) +# For PyTorch users, you can write your nn.Module in a normal way. +# By applying "optimize_torch" function on the resnet18 model, we obtain a new resnet18 model optimized by MetaSchedule -# We use default configuration for the first example -model_optimized_by_meta = optimize_torch( - SimpleModel(), example_input, tuning_config) -# ##################################################################### -# Optimized SimpleModel by TorchScript +class MyResNet18(torch.nn.Module): + def __init__(self, config, target=None): + super(MyResNet18, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = optimize_torch( + resnet18(), [torch.rand(1, 3, 224, 224)], config, target) + + def forward(self, input): + return self.resnet(input - self.means) + + +# Since the setting of the number of trials is large, the initialization could be slow (sometimes more than 3 hours!) 
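# Because that initialization is expensive, a caller may prefer to tune once,
# save the result, and reload it in later sessions instead of re-running the
# search. The sketch below is illustrative only (the variable and cache file
# name are hypothetical); the direct construction used by this tutorial
# follows right after it, and saving/loading itself works exactly as in the
# save/load section above.
import os

resnet18_cache = "tuned_resnet18_cache.pt"  # hypothetical cache file
if os.path.exists(resnet18_cache):
    cached_resnet18 = torch.load(resnet18_cache)  # reuse an earlier tuning run
else:
    # Slow path: runs the full MetaSchedule search before returning.
    cached_resnet18 = MyResNet18(tuning_config, target_cuda)
    torch.save(cached_resnet18, resnet18_cache)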
+meta_module_resnet18 = MyResNet18(tuning_config, target_cuda) + + +###################################################################### +# Define the resnet18 optimized by TorchScript # ------------------------------ -# As a comparison, we trace this module via `optimize_for_inference` function from TorchScript -model_optimized_by_jit = torch.jit.optimize_for_inference( - torch.jit.trace(SimpleModel(), example_input)) +# Besides, let us define a resnet18 model in a standard way. +# TorchScript also provide a built-in "optimize_for_inference" function to accelerate the inference. +class JitModule(torch.nn.Module): + def __init__(self): + super(JitModule, self).__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)).cuda() + self.resnet = torch.jit.optimize_for_inference( + torch.jit.script(resnet18().cuda().eval())) + + def forward(self, input): + return self.resnet(input - self.means) + + +jit_module_resnet18 = JitModule() ###################################################################### # Compare the performance between two scheduling approaches. @@ -91,18 +150,18 @@ def forward(self, x): results = [] for i in range(20): - test_input = torch.rand(20, 1, 10, 10) + test_input = torch.rand(1, 3, 224, 224).half().cuda() sub_label = f'[test {i}]' results.append(benchmark.Timer( - stmt='model_optimized_by_meta(test_input)', - setup='from __main__ import model_optimized_by_meta', + stmt='meta_module_resnet18(test_input)', + setup='from __main__ import meta_module_resnet18', globals={'test_input': test_input}, sub_label=sub_label, description='tuning by meta', ).blocked_autorange()) results.append(benchmark.Timer( - stmt='model_optimized_by_jit(test_input)', - setup='from __main__ import model_optimized_by_jit', + stmt='jit_module_resnet18(test_input)', + setup='from __main__ import jit_module_resnet18', globals={'test_input': test_input}, sub_label=sub_label, description='tuning by jit', @@ -112,18 +171,5 @@ def forward(self, x): compare = benchmark.Compare(results) compare.print() -###################################################################### -# Save/Load module -# ------------------------------ -# We can save and load our tuned module like standard nn.module - -# Let us run our tuned module and see the result -ret1 = model_optimized_by_meta(example_input) - -torch.save(model_optimized_by_meta, "meta_model.pt") -model_loaded = torch.load("meta_model.pt") - -# We load the module and run again and it will return the same result as above. -ret2 = model_loaded(example_input) - -tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) +# As above, we can save the module for future use +torch.save(meta_module_resnet18, "meta_tuned_resnet18.pt") diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 0bb734016757..378dfa041f0d 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -105,9 +105,9 @@ def optimize_torch( # Default setting. For a better tuning result the number could be set larger. 
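    # Caller-side sketch (not part of this function): rather than relying on the
    # fallback default that follows, a user can build an explicit TuneConfig and
    # pass it as the third argument, mirroring what the test suite does. The
    # trial counts below are deliberately tiny, illustrative values.
    #
    #     import torch
    #     from tvm.contrib.torch import optimize_torch
    #     from tvm.meta_schedule import TuneConfig
    #
    #     def matmul(x, w):
    #         return torch.matmul(x, w)
    #
    #     my_config = TuneConfig(
    #         strategy="evolutionary",
    #         num_trials_per_iter=4,
    #         max_trials_per_task=8,
    #         max_trials_global=16,
    #     )
    #     rt_mod = optimize_torch(matmul, (torch.randn(15, 20), torch.randn(20, 30)), my_config)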
tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=16, - max_trials_global=16, + num_trials_per_iter=1, + max_trials_per_task=4, + max_trials_global=0, ) # If `func` is already a traced module this statement makes no effect @@ -120,14 +120,12 @@ def optimize_torch( for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule if work_dir: - cm = contextlib.nullcontext() + cm = contextlib.nullcontext(work_dir) else: cm = tempfile.TemporaryDirectory() with cm as work_dir_path: - if work_dir is None: - work_dir = work_dir_path executor_factory = tune_relay( - mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir) + mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path) save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") save_runtime_mod(executor_factory.module) From e9622b3d864c951a4705829cee2c6cb612c0ad6c Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 22 Jun 2022 20:37:37 -0700 Subject: [PATCH 05/27] scrap your boilerplate --- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 142 ++++++++---------- 1 file changed, 64 insertions(+), 78 deletions(-) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 501fbbc7f48d..029a952ac305 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -31,6 +31,7 @@ #include #include +#include "../../../runtime/graph_executor/graph_executor_factory.h" #include "../base64.h" namespace tvm { @@ -44,6 +45,44 @@ struct ThreadLocalStore { } }; +using SerializationType = std::string; // executor factory stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(module); +} + +tvm::runtime::Module deserialize(SerializationType state) { + auto length = tvm::support::b64strlen(state); + + u_char bytes[length]; + memset(bytes, 0, sizeof(bytes)); + tvm::support::b64decode(state, bytes); + + const std::string name = tmpnam(NULL); + auto file_name = name + ".so"; + auto pFile = fopen(file_name.c_str(), "wb"); + fwrite(bytes, sizeof(u_char), length, pFile); + fclose(pFile); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." 
+ << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + tvm::runtime::Module ret = (*f)(file_name, ""); + + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + + return ret; +} + class OperatorModuleWrapper : public torch::jit::CustomClassHolder { public: OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } @@ -72,41 +111,9 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { } } - using SerializationType = std::string; // executor factory stream - - SerializationType Serialize() { - static const runtime::PackedFunc* f_to_str = - runtime::Registry::Get("script_torch.save_to_base64"); - ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`script_torch.save_to_tar` in the global registry"; - return (*f_to_str)(runtime_module); - } - - OperatorModuleWrapper(SerializationType state) { - auto length = tvm::support::b64strlen(state); - - u_char bytes[length]; - memset(bytes, 0, sizeof(bytes)); - tvm::support::b64decode(state, bytes); + SerializationType Serialize() { return serialize(runtime_module); } - const char* name = tmpnam(NULL); - auto file_name = std::string(name) + ".so"; - auto pFile = fopen(file_name.c_str(), "wb"); - fwrite(bytes, sizeof(u_char), length, pFile); - fclose(pFile); - - std::string load_f_name = "runtime.module.loadfile_so"; - const PackedFunc* f = runtime::Registry::Get(load_f_name); - ICHECK(f != nullptr) << "Loader for `.so` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - - runtime_module = (*f)(file_name, ""); - - ICHECK(remove(file_name.c_str()) == 0) - << "remove temporary file (" << file_name << ") unsuccessfully"; - } + OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } private: tvm::runtime::Module runtime_module; @@ -119,7 +126,12 @@ tvm::Device getDevice(const at::Tensor& tensor) { case at::DeviceType::CPU: dev.device_type = DLDeviceType::kDLCPU; if (dev.device_id == -1) { - dev.device_id = 1; + /* + * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning + * Thus we manually set the device ID as 0 for avoding potentially error of index out of + * bounds + */ + dev.device_id = 0; } break; case at::DeviceType::CUDA: @@ -134,7 +146,10 @@ tvm::Device getDevice(const at::Tensor& tensor) { class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { public: GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) - : executor_factory_(executor_factory) {} + : executor_factory_(executor_factory) { + CHECK(executor_factory_->IsInstance()) + << "module is not an instance of GraphExecutorFactory"; + } GraphExecutorFactoryWrapper() : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {} @@ -183,41 +198,9 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { return outputs; } - using SerializationType = std::string; // executor factory stream - - SerializationType Serialize() { - static const runtime::PackedFunc* f_to_str = - runtime::Registry::Get("script_torch.save_to_base64"); - ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`script_torch.save_to_tar` in the global registry"; - return (*f_to_str)(executor_factory_); - } - - 
GraphExecutorFactoryWrapper(SerializationType state) { - auto length = tvm::support::b64strlen(state); + SerializationType Serialize() { return serialize(executor_factory_); } - u_char bytes[length]; - memset(bytes, 0, sizeof(bytes)); - tvm::support::b64decode(state, bytes); - - const char* name = tmpnam(NULL); - auto file_name = std::string(name) + ".so"; - auto pFile = fopen(file_name.c_str(), "wb"); - fwrite(bytes, sizeof(u_char), length, pFile); - fclose(pFile); - - std::string load_f_name = "runtime.module.loadfile_so"; - const PackedFunc* f = runtime::Registry::Get(load_f_name); - ICHECK(f != nullptr) << "Loader for `.so` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - - executor_factory_ = (*f)(file_name, ""); - - ICHECK(remove(file_name.c_str()) == 0) - << "remove temporary file (" << file_name << ") unsuccessfully"; - } + GraphExecutorFactoryWrapper(SerializationType state) { executor_factory_ = deserialize(state); } private: tvm::runtime::Module executor_factory_; @@ -232,11 +215,13 @@ TORCH_LIBRARY(tvm_torch, m) { m.class_("OperatorModuleWrapper") .def(torch::init<>()) .def("forward", &OperatorModuleWrapper::forward) - .def_pickle([](const c10::intrusive_ptr& self) - -> OperatorModuleWrapper::SerializationType { return self->Serialize(); }, - [](OperatorModuleWrapper::SerializationType state) { - return c10::make_intrusive(state); - }); + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializationType { + return self->Serialize(); + }, + [](SerializationType state) { + return c10::make_intrusive(state); + }); } TORCH_LIBRARY(tvm_tuning, m) { @@ -244,9 +229,10 @@ TORCH_LIBRARY(tvm_tuning, m) { .def(torch::init<>()) .def("forward", &GraphExecutorFactoryWrapper::forward) .def_pickle( - [](const c10::intrusive_ptr& self) - -> GraphExecutorFactoryWrapper::SerializationType { return self->Serialize(); }, - [](GraphExecutorFactoryWrapper::SerializationType state) { + [](const c10::intrusive_ptr& self) -> SerializationType { + return self->Serialize(); + }, + [](SerializationType state) { return c10::make_intrusive(state); }); } From 763806c537c442174ccbd5bcade3638eeed6cdfa Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 23 Jun 2022 00:18:15 -0700 Subject: [PATCH 06/27] as_torch polished --- apps/pt_tvmdsoop/tests/test_as_torch.py | 32 ++++++- gallery/how_to/work_with_pytorch/as_torch.py | 93 ++++++++++++-------- python/tvm/contrib/torch/as_torch.py | 35 +++----- 3 files changed, 100 insertions(+), 60 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 2b1185681493..ebdf02b5d6a4 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -59,6 +59,20 @@ def main(a: T.handle, b: T.handle): vi = T.axis.spatial(8, i) B[vi] = A[vi] + 1.0 +@as_torch +@tvm.script.ir_module +class ModuleGPU: + @T.prim_func + def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_2 in T.thread_binding(2, thread="threadIdx.x"): + for i_1 in T.serial(2): + with T.block("B"): + vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2) + T.reads(A[vi]) + T.writes(B[vi]) + B[vi] = A[vi] + T.float32(1) class MinuesOnes(torch.nn.Module): def __init__(self): @@ -101,9 +115,22 @@ def 
test_tvmscript_torch_decorator(): MyModule(q1, q2) tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + +def test_tvmscript_torch_gpu(): + s1 = np.arange(8).astype("float32") + + cuda0 = torch.device('cuda:0') + q1 = torch.arange(8, device=cuda0).type(torch.float32) + q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) + + numpy_result = s1 + 1 + + ModuleGPU(q1, q2) + + tvm.testing.assert_allclose(q2.cpu().numpy(), numpy_result, atol=1e-5, rtol=1e-5) -def test_torch_with_tvmscirpt(): +def test_torch_with_tvmscript(): s1 = np.arange(8).astype("float32") q1 = torch.arange(8).type(torch.float32) @@ -119,4 +146,5 @@ def test_torch_with_tvmscirpt(): if __name__ == "__main__": test_tvmscript_torch_matmul() test_tvmscript_torch_decorator() - test_torch_with_tvmscirpt() + test_tvmscript_torch_gpu() + test_torch_with_tvmscript() diff --git a/gallery/how_to/work_with_pytorch/as_torch.py b/gallery/how_to/work_with_pytorch/as_torch.py index 0d88ada429a6..739b615f89dc 100644 --- a/gallery/how_to/work_with_pytorch/as_torch.py +++ b/gallery/how_to/work_with_pytorch/as_torch.py @@ -27,12 +27,12 @@ """ # Import Tvm and PyTorch, as well as necessary libraries import tvm +import tvm.testing +import numpy as np import torch +import torch.nn from tvm.contrib.torch import as_torch from tvm.script import tir as T -import numpy as np -import torch.nn -import tvm.testing ###################################################################### @@ -58,29 +58,19 @@ def main(a: T.handle, b: T.handle): # ------------------------------- -def test_tvmscript_torch_decorator(): - s1 = np.arange(8).astype("float32") - - # Define two torch tensors - q1 = torch.arange(8).type(torch.float32) - q2 = torch.zeros((8,), dtype=torch.float32) - - # Result from numpy - numpy_result = s1 + 1 - - tvm_module = MyModule +# Define two torch tensors +q1 = torch.arange(8).type(torch.float32) +q2 = torch.zeros((8,), dtype=torch.float32) - # We call `MyModule` as PyTorch module's forward - tvm_module(q1, q2) +# Call the function directly, the result is stored at `q2` +MyModule(q1, q2) - # Testing. No output implies that tensors are equal - tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) +# Testing. No output implies that tensors are equal +tvm.testing.assert_allclose(q2.numpy(), (q1+1).numpy(), atol=1e-5, rtol=1e-5) -test_tvmscript_torch_decorator() - ###################################################################### -# Another example: matrix multiplication with a limit form meta-programming +# The second example: matrix multiplication with a limit form meta-programming # ------------------------------- # As above, we can add `as_torch` decorator to a Tensor IR function. 
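# Before moving on, note that the decorated object also behaves like an
# ordinary nn.Module, so it can be embedded inside a larger PyTorch model.
# The wrapper below is only an illustrative sketch (its name and the extra
# scaling step are made up for this example); it follows the
# destination-passing convention used above, where the TVM kernel writes its
# result into the output tensor we hand it.


class AddOneThenScale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.engine = MyModule  # the @as_torch-decorated IRModule defined above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x)
        self.engine(x, out)  # TVM kernel computes x + 1 into `out`
        return out * 2.0  # ordinary PyTorch op applied to the result


inp = torch.arange(8).type(torch.float32)
res = AddOneThenScale()(inp)
tvm.testing.assert_allclose(res.numpy(), ((inp + 1) * 2).numpy(), atol=1e-5, rtol=1e-5)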
@@ -105,26 +95,55 @@ def f(a: T.handle, b: T.handle, c: T.handle) -> None: # ------------------------------- -def test_tvmscript_torch_matmul(): - # Create two 128 x 128 matrixs as input - s1 = np.random.rand(128, 128).astype("float32") - s2 = np.random.rand(128, 128).astype("float32") - s3 = np.zeros((128, 128)).astype("float32") +# Create two 128 x 128 matrixes as input +s1 = np.random.rand(128, 128).astype("float32") +s2 = np.random.rand(128, 128).astype("float32") +s3 = np.zeros((128, 128)).astype("float32") + +q1 = torch.from_numpy(s1) +q2 = torch.from_numpy(s2) +q3 = torch.from_numpy(s3) - q1 = torch.from_numpy(s1) - q2 = torch.from_numpy(s2) - q3 = torch.from_numpy(s3) +# Result from numpy +numpy_result = np.matmul(s1, np.transpose(s2)) - # Result from numpy - numpy_result = np.matmul(s1, np.transpose(s2)) +# Instantiate the `matmul` function by passing the parameters of shapes and datatype +tvm_module = matmul(128, 128, 128, "float32") - # Instantiate the `matmul` function by passing the parameters of shapes and datatype - tvm_module = matmul(128, 128, 128, "float32") +# Run the operator +tvm_module(q1, q2, q3) - # Run the operator - tvm_module(q1, q2, q3) +tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) - tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) +###################################################################### +# Last example: GPU supporting +# ------------------------------- +# In such an example, we demonstrate our method does support module built upon GPU +# The code below is the GPU version of `MyModule` -test_tvmscript_torch_matmul() +@as_torch +@tvm.script.ir_module +class ModuleGPU: + @T.prim_func + def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_2 in T.thread_binding(2, thread="threadIdx.x"): + for i_1 in T.serial(2): + with T.block("B"): + vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2) + T.reads(A[vi]) + T.writes(B[vi]) + B[vi] = A[vi] + T.float32(1) + + +# Define two torch tensors, on GPU +cuda0 = torch.device('cuda:0') +q1 = torch.arange(8, device=cuda0).type(torch.float32) +q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) + +# Running and testing +ModuleGPU(q1, q2) +tvm.testing.assert_allclose( + q2.cpu().numpy(), (q1+1).cpu().numpy(), atol=1e-5, rtol=1e-5) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 99a2a7b023cb..49e2be6b390e 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -28,33 +28,26 @@ class OperatorModuleWrapper(torch.nn.Module): def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): super().__init__() - self.engine_cpu = None - self.engine_cuda = None - self.ir_module = module + self.rt_module = None # runtime module + self.ir_module = module # IR moudle - def build_cpu(self): - runtime_module = tvm.build(self.ir_module) + def build(self, target = None): + runtime_module = tvm.build(self.ir_module, target = target) func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) - self.engine_cpu = torch.classes.tvm_torch.OperatorModuleWrapper() - - def build_cuda(self): - # If the module build on cuda, we won't call the C++ code since some information is missing - runtime_module = tvm.build(self.ir_module, target=tvm.target.cuda()) - self.engine_cuda = 
runtime_module + self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper() def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: - if torch_inputs[0].is_cuda: - if self.engine_cuda is None: - self.build_cuda() - return self.engine_cuda.forward(torch_inputs) - else: - # We force the tensor inputs to be on cpu. - torch_inputs = tuple(map(lambda x: x.cpu(), torch_inputs)) - if self.engine_cpu is None: - self.build_cpu() - return self.engine_cpu.forward(torch_inputs) + if self.rt_module is None: + if torch_inputs[0].is_cuda: + self.build(target = "cuda") + elif torch_inputs[0].device.type == "cpu": + self.build() + else: + raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet") + + return self.rt_module.forward(torch_inputs) def as_torch( From 1068f7d01473d3398ef8c2f10d4d0784458f615b Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 23 Jun 2022 00:23:00 -0700 Subject: [PATCH 07/27] configuration fixed --- apps/pt_tvmdsoop/tests/test_as_torch.py | 16 ++++------------ apps/pt_tvmdsoop/tests/test_optimize_torch.py | 6 +++--- .../how_to/work_with_pytorch/optimize_torch.py | 6 +++--- python/tvm/contrib/torch/optimize_torch.py | 6 +++--- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index ebdf02b5d6a4..b84d9fbc030a 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -105,33 +105,25 @@ def test_tvmscript_torch_matmul(): def test_tvmscript_torch_decorator(): - s1 = np.arange(8).astype("float32") - q1 = torch.arange(8).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32) - numpy_result = s1 + 1 - MyModule(q1, q2) - tvm.testing.assert_allclose(q2.numpy(), numpy_result, atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(q2.numpy(), (q1+1).numpy(), atol=1e-5, rtol=1e-5) def test_tvmscript_torch_gpu(): - s1 = np.arange(8).astype("float32") - cuda0 = torch.device('cuda:0') q1 = torch.arange(8, device=cuda0).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) - numpy_result = s1 + 1 - ModuleGPU(q1, q2) - tvm.testing.assert_allclose(q2.cpu().numpy(), numpy_result, atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(q2.cpu().numpy(), (q1+1).cpu().numpy(), atol=1e-5, rtol=1e-5) def test_torch_with_tvmscript(): - s1 = np.arange(8).astype("float32") + ref_result = np.arange(8).astype("float32") q1 = torch.arange(8).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32) @@ -140,7 +132,7 @@ def test_torch_with_tvmscript(): ret = nn_module.forward(q1, q2) - tvm.testing.assert_allclose(ret.numpy(), s1, atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5) if __name__ == "__main__": diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py index 451fd80d552f..5c68233d654d 100644 --- a/apps/pt_tvmdsoop/tests/test_optimize_torch.py +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -32,9 +32,9 @@ # default config for testing config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=2, - max_trials_per_task=4, - max_trials_global=0, + num_trials_per_iter=4, + max_trials_per_task=8, + max_trials_global=16, ) def test_matmul_tuning_relay(): diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py index a033893259be..eaf7eefec674 100644 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ 
b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -98,9 +98,9 @@ def forward(self, x): # We can define the configuration by ourselves tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=1, - max_trials_per_task=1, - max_trials_global=0, + num_trials_per_iter=64, + max_trials_per_task=20000, + max_trials_global=20000, ) # For PyTorch users, you can write your nn.Module in a normal way. diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 378dfa041f0d..eb8ab97871c2 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -105,9 +105,9 @@ def optimize_torch( # Default setting. For a better tuning result the number could be set larger. tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=1, - max_trials_per_task=4, - max_trials_global=0, + num_trials_per_iter=4, + max_trials_per_task=8, + max_trials_global=16, ) # If `func` is already a traced module this statement makes no effect From 76104b6f0f9473e8034d2bdb13c279504e4ccf77 Mon Sep 17 00:00:00 2001 From: Yaoda Zhou Date: Fri, 24 Jun 2022 09:01:49 +0800 Subject: [PATCH 08/27] Apply suggestions from code review Co-authored-by: Lite Ye --- gallery/how_to/work_with_pytorch/optimize_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py index eaf7eefec674..17f7b23b9ec7 100644 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -18,7 +18,7 @@ Compile PyTorch Models ====================== **Author**: `Yaoda Zhou `_ -This article is an introductory tutorial to optimize PyTorch models by MetaSchedule. +This article is an introductory tutorial to optimize PyTorch models by using `tvm.contrib.torch.optimize_torch`. For us to follow this tutorial, PyTorch as well as TorchVision should be installed. For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as .. code-block:: bash From 9acc6481afada81a25b66921505584c4555688a7 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 23 Jun 2022 18:26:08 -0700 Subject: [PATCH 09/27] more document --- gallery/how_to/work_with_pytorch/as_torch.py | 2 ++ .../how_to/work_with_pytorch/optimize_torch.py | 2 ++ python/tvm/contrib/torch/optimize_torch.py | 10 ++++++---- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 17 +++++++++++++++-- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/gallery/how_to/work_with_pytorch/as_torch.py b/gallery/how_to/work_with_pytorch/as_torch.py index 739b615f89dc..04b317c296d8 100644 --- a/gallery/how_to/work_with_pytorch/as_torch.py +++ b/gallery/how_to/work_with_pytorch/as_torch.py @@ -24,6 +24,8 @@ For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as .. code-block:: bash conda install -c conda-forge pytorch-gpu + python -c "import torch;print(torch.compiled_with_cxx11_abi())" +If the output printed in terminal is `True` then PyTorch built with Cxx11 ABI is successfully installed. 
""" # Import Tvm and PyTorch, as well as necessary libraries import tvm diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py index 17f7b23b9ec7..dfcd982bcc83 100644 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ b/gallery/how_to/work_with_pytorch/optimize_torch.py @@ -23,6 +23,8 @@ For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as .. code-block:: bash conda install -c conda-forge pytorch-gpu + python -c "import torch;print(torch.compiled_with_cxx11_abi())" +If the output printed in terminal is `True` then PyTorch built with Cxx11 ABI is successfully installed. """ # Import TVM import tvm diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index eb8ab97871c2..ca461af3da6b 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -80,6 +80,8 @@ def optimize_torch( tuning_config : tvm.meta_schedule.TuneConfig The configuration of tuning by MetaSchedule. + We suggest users to provide their own setting, otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. target : Optional[Union[str, Target]] The target of the compilation. @@ -102,12 +104,12 @@ def optimize_torch( if tuning_config: pass else: - # Default setting. For a better tuning result the number could be set larger. + # Default setting. For a better tuning result the number could be set large. tuning_config = TuneConfig( strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=8, - max_trials_global=16, + num_trials_per_iter=64, + max_trials_per_task=2000, + max_trials_global=2000, ) # If `func` is already a traced module this statement makes no effect diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 029a952ac305..7bfe6f4800c3 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -37,6 +37,9 @@ namespace tvm { namespace contrib { +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ struct ThreadLocalStore { tvm::runtime::Module mod; static ThreadLocalStore* ThreadLocal() { @@ -45,7 +48,7 @@ struct ThreadLocalStore { } }; -using SerializationType = std::string; // executor factory stream +using SerializationType = std::string; // base64 stream SerializationType serialize(tvm::runtime::Module module) { static const runtime::PackedFunc* f_to_str = @@ -83,6 +86,11 @@ tvm::runtime::Module deserialize(SerializationType state) { return ret; } +/** + * @brief A Torch's module which wraps TVM's OperatorModule Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. 
+ */ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { public: OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } @@ -128,7 +136,7 @@ tvm::Device getDevice(const at::Tensor& tensor) { if (dev.device_id == -1) { /* * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning - * Thus we manually set the device ID as 0 for avoding potentially error of index out of + * Thus we manually set the device ID as 0 for avoiding potentially error of index out of * bounds */ dev.device_id = 0; @@ -143,6 +151,11 @@ tvm::Device getDevice(const at::Tensor& tensor) { return dev; } +/** + * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. + */ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { public: GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) From d594c488e5762d67464cb6f459597bffd54256e4 Mon Sep 17 00:00:00 2001 From: juda Date: Sun, 26 Jun 2022 18:25:17 -0700 Subject: [PATCH 10/27] file deleter --- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 7bfe6f4800c3..ef5a96dcde7e 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -58,6 +58,17 @@ SerializationType serialize(tvm::runtime::Module module) { return (*f_to_str)(module); } +struct Deleter { // deleter + Deleter(std::string file_name) { + this->file_name = file_name; + }; + void operator()(FILE* p) const { + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + }; + std::string file_name; +}; + tvm::runtime::Module deserialize(SerializationType state) { auto length = tvm::support::b64strlen(state); @@ -67,9 +78,10 @@ tvm::runtime::Module deserialize(SerializationType state) { const std::string name = tmpnam(NULL); auto file_name = name + ".so"; - auto pFile = fopen(file_name.c_str(), "wb"); - fwrite(bytes, sizeof(u_char), length, pFile); - fclose(pFile); + std::unique_ptr pFile( + fopen(file_name.c_str(), "wb"), Deleter(file_name)); + fwrite(bytes, sizeof(u_char), length, pFile.get()); + fclose(pFile.get()); std::string load_f_name = "runtime.module.loadfile_so"; const PackedFunc* f = runtime::Registry::Get(load_f_name); @@ -80,9 +92,6 @@ tvm::runtime::Module deserialize(SerializationType state) { tvm::runtime::Module ret = (*f)(file_name, ""); - ICHECK(remove(file_name.c_str()) == 0) - << "remove temporary file (" << file_name << ") unsuccessfully"; - return ret; } From 9cc0f75b7c08689c20d0b4424aaa84f4c46477d6 Mon Sep 17 00:00:00 2001 From: juda Date: Sun, 26 Jun 2022 18:59:22 -0700 Subject: [PATCH 11/27] optimize deleter --- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index ef5a96dcde7e..5ca6307b00b0 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -58,15 +58,14 @@ SerializationType serialize(tvm::runtime::Module module) { return (*f_to_str)(module); } -struct Deleter { // 
deleter - Deleter(std::string file_name) { - this->file_name = file_name; - }; - void operator()(FILE* p) const { - ICHECK(remove(file_name.c_str()) == 0) +struct Deleter { // deleter + Deleter(std::string file_name) { this->file_name = file_name; }; + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) << "remove temporary file (" << file_name << ") unsuccessfully"; - }; - std::string file_name; + }; + std::string file_name; }; tvm::runtime::Module deserialize(SerializationType state) { @@ -78,10 +77,9 @@ tvm::runtime::Module deserialize(SerializationType state) { const std::string name = tmpnam(NULL); auto file_name = name + ".so"; - std::unique_ptr pFile( - fopen(file_name.c_str(), "wb"), Deleter(file_name)); + std::unique_ptr pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); fwrite(bytes, sizeof(u_char), length, pFile.get()); - fclose(pFile.get()); + fflush(pFile.get()); std::string load_f_name = "runtime.module.loadfile_so"; const PackedFunc* f = runtime::Registry::Get(load_f_name); @@ -89,7 +87,7 @@ tvm::runtime::Module deserialize(SerializationType state) { << " resolved to (" << load_f_name << ") in the global registry." << "Ensure that you have loaded the correct runtime code, and" << "that you are on the correct hardware architecture."; - + tvm::runtime::Module ret = (*f)(file_name, ""); return ret; From 64e48a7adde49066a74390e0bf0764e0f6eab258 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 28 Jun 2022 07:27:23 -0700 Subject: [PATCH 12/27] drop how-to guides --- gallery/how_to/work_with_pytorch/as_torch.py | 151 --------------- .../work_with_pytorch/optimize_torch.py | 177 ------------------ 2 files changed, 328 deletions(-) delete mode 100644 gallery/how_to/work_with_pytorch/as_torch.py delete mode 100644 gallery/how_to/work_with_pytorch/optimize_torch.py diff --git a/gallery/how_to/work_with_pytorch/as_torch.py b/gallery/how_to/work_with_pytorch/as_torch.py deleted file mode 100644 index 04b317c296d8..000000000000 --- a/gallery/how_to/work_with_pytorch/as_torch.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Wrap Your Tensor IR with PyTorch Module -====================== -**Author**: `Yaoda Zhou `_ -This article is an introductory tutorial to wrap the Tensor IR code with PyTorch module. -By the decorator `as_torch`, users are able to import a Tensor IR code in PyTorch with a low cost. -For us to follow this tutorial, PyTorch as well as TorchVision should be installed. -For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as -.. 
code-block:: bash - conda install -c conda-forge pytorch-gpu - python -c "import torch;print(torch.compiled_with_cxx11_abi())" -If the output printed in terminal is `True` then PyTorch built with Cxx11 ABI is successfully installed. -""" -# Import Tvm and PyTorch, as well as necessary libraries -import tvm -import tvm.testing -import numpy as np -import torch -import torch.nn -from tvm.contrib.torch import as_torch -from tvm.script import tir as T - - -###################################################################### -# Define an example of vector add -# (This example could be found at https://tvm.apache.org/docs/tutorial/tensor_ir_blitz_course.html) -# ------------------------------- -# Our `as_torch` is a simple decorator: put it on any Tensor IR code and it will convert it into PyTorch module automatically. -@as_torch -@tvm.script.ir_module -class MyModule: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, (8,), dtype="float32") - B = T.match_buffer(b, (8,), dtype="float32") - for i in range(8): - with T.block("B"): - vi = T.axis.spatial(8, i) - B[vi] = A[vi] + 1.0 - -###################################################################### -# Write a test case: Tvm's testing is used to compare two tensors -# ------------------------------- - - -# Define two torch tensors -q1 = torch.arange(8).type(torch.float32) -q2 = torch.zeros((8,), dtype=torch.float32) - -# Call the function directly, the result is stored at `q2` -MyModule(q1, q2) - -# Testing. No output implies that tensors are equal -tvm.testing.assert_allclose(q2.numpy(), (q1+1).numpy(), atol=1e-5, rtol=1e-5) - - -###################################################################### -# The second example: matrix multiplication with a limit form meta-programming -# ------------------------------- -# As above, we can add `as_torch` decorator to a Tensor IR function. - - -@as_torch -def matmul(M: int, N: int, K: int, dtype: str): - @T.prim_func - def f(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [M, K], dtype=dtype) - B = T.match_buffer(b, [N, K], dtype=dtype) - C = T.match_buffer(c, [M, N], dtype=dtype) - for i, j, k in T.grid(M, N, K): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - return f - -###################################################################### -# Test case for `matmul` function. 
-# ------------------------------- - - -# Create two 128 x 128 matrixes as input -s1 = np.random.rand(128, 128).astype("float32") -s2 = np.random.rand(128, 128).astype("float32") -s3 = np.zeros((128, 128)).astype("float32") - -q1 = torch.from_numpy(s1) -q2 = torch.from_numpy(s2) -q3 = torch.from_numpy(s3) - -# Result from numpy -numpy_result = np.matmul(s1, np.transpose(s2)) - -# Instantiate the `matmul` function by passing the parameters of shapes and datatype -tvm_module = matmul(128, 128, 128, "float32") - -# Run the operator -tvm_module(q1, q2, q3) - -tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5) - -###################################################################### -# Last example: GPU supporting -# ------------------------------- -# In such an example, we demonstrate our method does support module built upon GPU -# The code below is the GPU version of `MyModule` - - -@as_torch -@tvm.script.ir_module -class ModuleGPU: - @T.prim_func - def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i_0 in T.thread_binding(2, thread="blockIdx.x"): - for i_2 in T.thread_binding(2, thread="threadIdx.x"): - for i_1 in T.serial(2): - with T.block("B"): - vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2) - T.reads(A[vi]) - T.writes(B[vi]) - B[vi] = A[vi] + T.float32(1) - - -# Define two torch tensors, on GPU -cuda0 = torch.device('cuda:0') -q1 = torch.arange(8, device=cuda0).type(torch.float32) -q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) - -# Running and testing -ModuleGPU(q1, q2) -tvm.testing.assert_allclose( - q2.cpu().numpy(), (q1+1).cpu().numpy(), atol=1e-5, rtol=1e-5) diff --git a/gallery/how_to/work_with_pytorch/optimize_torch.py b/gallery/how_to/work_with_pytorch/optimize_torch.py deleted file mode 100644 index dfcd982bcc83..000000000000 --- a/gallery/how_to/work_with_pytorch/optimize_torch.py +++ /dev/null @@ -1,177 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Compile PyTorch Models -====================== -**Author**: `Yaoda Zhou `_ -This article is an introductory tutorial to optimize PyTorch models by using `tvm.contrib.torch.optimize_torch`. -For us to follow this tutorial, PyTorch as well as TorchVision should be installed. -For avoiding potential "undefined symbol" issue, we strongly recommend to install PyTorch built with Cxx11 ABI from Conda, as -.. code-block:: bash - conda install -c conda-forge pytorch-gpu - python -c "import torch;print(torch.compiled_with_cxx11_abi())" -If the output printed in terminal is `True` then PyTorch built with Cxx11 ABI is successfully installed. 
-""" -# Import TVM -import tvm -import tvm.testing -# Import `optimize_torch` function -from tvm.contrib.torch import optimize_torch -from tvm.meta_schedule import TuneConfig - -# Import PyTorch -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision.models import resnet18 -# Import library for profiling -import torch.utils.benchmark as benchmark - - -###################################################################### -# Define a simple module written by PyTorch -# ------------------------------ - -class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 20, 5) - self.conv2 = nn.Conv2d(20, 20, 5) - - def forward(self, x): - x = F.relu(self.conv1(x)) - return F.relu(self.conv2(x)) - -###################################################################### -# Optimized SimpleModel by TVM MetaSchedule -# ------------------------------ -# We provide a `optimize_torch` function, which have a similar usage as `torch.jit.trace`. -# For the function, we have five parameters need to provide. -# If the third parameter `tuning_config` is not provided, a default configuration is loaded. -# If the parameter `target` is empty, the model will deploy on CPU. - - -example_input = torch.randn(20, 1, 10, 10) - -# We use default configuration for the first example -model_optimized_by_meta = optimize_torch( - SimpleModel(), example_input) - -###################################################################### -# Save/Load module -# ------------------------------ -# We can save and load our tuned module like standard nn.module - -# Let us run our tuned module and see the result -ret1 = model_optimized_by_meta(example_input) - -torch.save(model_optimized_by_meta, "meta_model.pt") -model_loaded = torch.load("meta_model.pt") - -# We load the module and run again and it will return the same result as above. -ret2 = model_loaded(example_input) - -tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) - -###################################################################### -# Define the resnet18 optimized by MetaSchedule -# ------------------------------ -# Another example, we compare the two optimizers about the performance of resnet18 -# For learning how to define a resnet18 model via PyTorch's nn.Module, -# you can refer to https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting - -# In our working machine, the GPU model is nvidia/geforce-rtx-3070. -target_cuda = "nvidia/geforce-rtx-3070" - -# We can define the configuration by ourselves -tuning_config = TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=20000, - max_trials_global=20000, -) - -# For PyTorch users, you can write your nn.Module in a normal way. -# By applying "optimize_torch" function on the resnet18 model, we obtain a new resnet18 model optimized by MetaSchedule - - -class MyResNet18(torch.nn.Module): - def __init__(self, config, target=None): - super(MyResNet18, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = optimize_torch( - resnet18(), [torch.rand(1, 3, 224, 224)], config, target) - - def forward(self, input): - return self.resnet(input - self.means) - - -# Since the setting of the number of trials is large, the initialization could be slow (sometimes more than 3 hours!) 
-meta_module_resnet18 = MyResNet18(tuning_config, target_cuda) - - -###################################################################### -# Define the resnet18 optimized by TorchScript -# ------------------------------ -# Besides, let us define a resnet18 model in a standard way. -# TorchScript also provide a built-in "optimize_for_inference" function to accelerate the inference. - -class JitModule(torch.nn.Module): - def __init__(self): - super(JitModule, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = torch.jit.optimize_for_inference( - torch.jit.script(resnet18().cuda().eval())) - - def forward(self, input): - return self.resnet(input - self.means) - - -jit_module_resnet18 = JitModule() - -###################################################################### -# Compare the performance between two scheduling approaches. -# ------------------------------ -# Using PyTorch's benchmark Compare class, we can have a straightforward comparison between two inference models. - -results = [] -for i in range(20): - test_input = torch.rand(1, 3, 224, 224).half().cuda() - sub_label = f'[test {i}]' - results.append(benchmark.Timer( - stmt='meta_module_resnet18(test_input)', - setup='from __main__ import meta_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by meta', - ).blocked_autorange()) - results.append(benchmark.Timer( - stmt='jit_module_resnet18(test_input)', - setup='from __main__ import jit_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by jit', - ).blocked_autorange()) - -# We can print the results on screen. -compare = benchmark.Compare(results) -compare.print() - -# As above, we can save the module for future use -torch.save(meta_module_resnet18, "meta_tuned_resnet18.pt") From 66911190eddac0b090c87708d7d9f34bb9ffab44 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 30 Jun 2022 01:56:42 -0700 Subject: [PATCH 13/27] clang-format-10 --- src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 5ca6307b00b0..fb82261de6d1 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -87,7 +87,7 @@ tvm::runtime::Module deserialize(SerializationType state) { << " resolved to (" << load_f_name << ") in the global registry." 
<< "Ensure that you have loaded the correct runtime code, and" << "that you are on the correct hardware architecture."; - + tvm::runtime::Module ret = (*f)(file_name, ""); return ret; From 0dc4f4253226cc35713f026ddba6600e541f274f Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 30 Jun 2022 19:35:38 -0700 Subject: [PATCH 14/27] formatter changes --- apps/pt_tvmdsoop/tests/test_as_torch.py | 12 ++- apps/pt_tvmdsoop/tests/test_optimize_torch.py | 95 +++++++++++-------- python/tvm/contrib/torch/__init__.py | 3 +- python/tvm/contrib/torch/as_torch.py | 27 ++++-- python/tvm/contrib/torch/optimize_torch.py | 15 ++- 5 files changed, 86 insertions(+), 66 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index b84d9fbc030a..0402605672a5 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -39,6 +39,7 @@ def f(a: T.handle, b: T.handle, c: T.handle) -> None: with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + return f @@ -59,6 +60,7 @@ def main(a: T.handle, b: T.handle): vi = T.axis.spatial(8, i) B[vi] = A[vi] + 1.0 + @as_torch @tvm.script.ir_module class ModuleGPU: @@ -74,6 +76,7 @@ def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: T.writes(B[vi]) B[vi] = A[vi] + T.float32(1) + class MinuesOnes(torch.nn.Module): def __init__(self): super(MinuesOnes, self).__init__() @@ -110,16 +113,17 @@ def test_tvmscript_torch_decorator(): MyModule(q1, q2) - tvm.testing.assert_allclose(q2.numpy(), (q1+1).numpy(), atol=1e-5, rtol=1e-5) - + tvm.testing.assert_allclose(q2.numpy(), (q1 + 1).numpy(), atol=1e-5, rtol=1e-5) + + def test_tvmscript_torch_gpu(): - cuda0 = torch.device('cuda:0') + cuda0 = torch.device("cuda:0") q1 = torch.arange(8, device=cuda0).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) ModuleGPU(q1, q2) - tvm.testing.assert_allclose(q2.cpu().numpy(), (q1+1).cpu().numpy(), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5) def test_torch_with_tvmscript(): diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py index 5c68233d654d..02dc96701d52 100644 --- a/apps/pt_tvmdsoop/tests/test_optimize_torch.py +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -21,7 +21,7 @@ import tempfile import torch -import torch.utils.benchmark as benchmark +from torch.utils import benchmark from torchvision.models import resnet18 import tvm @@ -31,35 +31,37 @@ # default config for testing config = TuneConfig( - strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=8, - max_trials_global=16, - ) + strategy="evolutionary", + num_trials_per_iter=4, + max_trials_per_task=8, + max_trials_global=16, +) + def test_matmul_tuning_relay(): def matmul(x, w): return torch.matmul(x, w) - + x = torch.randn(15, 20) w = torch.randn(20, 30) example_inputs = (x, w) - - rt_mod = optimize_torch(matmul, example_inputs) + rt_mod = optimize_torch(matmul, example_inputs) torch_answer = torch.matmul(x, w).numpy() tvm_answer = rt_mod(x, w).numpy() - + tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5) - + + class InnerModel(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(1, 20, 5) - + def forward(self, x): return torch.nn.functional.relu(self.conv(x)) + class SimpleModel(torch.nn.Module): def __init__(self): super().__init__() @@ -69,16 +71,17 @@ def 
__init__(self): def forward(self, x): x = self.relu(x) return torch.nn.functional.relu(self.conv(x)) - + + def test_nested_module(): simple_module = SimpleModel() example_input = torch.randn(20, 1, 10, 10) - optimized_module = optimize_torch( - simple_module, example_input) + optimized_module = optimize_torch(simple_module, example_input) ret1 = simple_module(example_input).detach().numpy() ret2 = optimized_module(example_input).detach().numpy() tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) - + + def test_save_load_function(): def foo(x): return 2 * x + 1 @@ -91,60 +94,68 @@ def foo(x): loaded_mod = torch.load(tmp.name) ret2 = loaded_mod(example_input) tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5) - + + class MyResNet18(torch.nn.Module): def __init__(self, config, target=None): super(MyResNet18, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = optimize_torch( - resnet18(), [torch.rand(1, 3, 224, 224)], config, target) + self.means = torch.nn.Parameter( + torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1) + ).cuda() + self.resnet = optimize_torch(resnet18(), [torch.rand(1, 3, 224, 224)], config, target) def forward(self, input): return self.resnet(input - self.means) + class JitModule(torch.nn.Module): def __init__(self): super(JitModule, self).__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)).cuda() - self.resnet = torch.jit.optimize_for_inference( - torch.jit.script(resnet18().cuda().eval())) + self.means = torch.nn.Parameter( + torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1) + ).cuda() + self.resnet = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval())) def forward(self, input): return self.resnet(input - self.means) + if torch.cuda.is_available(): target_cuda = "nvidia/geforce-rtx-3070" meta_module_resnet18 = MyResNet18(config, target_cuda) jit_module_resnet18 = JitModule() + def compare_optimize_resnet18_to_torchscript(): results = [] for i in range(20): test_input = torch.rand(1, 3, 224, 224).half().cuda() - sub_label = f'[test {i}]' - results.append(benchmark.Timer( - stmt='meta_module_resnet18(test_input)', - setup='from __main__ import meta_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by meta', - ).blocked_autorange()) - results.append(benchmark.Timer( - stmt='jit_module_resnet18(test_input)', - setup='from __main__ import jit_module_resnet18', - globals={'test_input': test_input}, - sub_label=sub_label, - description='tuning by jit', - ).blocked_autorange()) + sub_label = f"[test {i}]" + results.append( + benchmark.Timer( + stmt="meta_module_resnet18(test_input)", + setup="from __main__ import meta_module_resnet18", + globals={"test_input": test_input}, + sub_label=sub_label, + description="tuning by meta", + ).blocked_autorange() + ) + results.append( + benchmark.Timer( + stmt="jit_module_resnet18(test_input)", + setup="from __main__ import jit_module_resnet18", + globals={"test_input": test_input}, + sub_label=sub_label, + description="tuning by jit", + ).blocked_autorange() + ) compare = benchmark.Compare(results) compare.print() + if __name__ == "__main__": - test_matmul_tuning_relay() + test_matmul_tuning_relay() test_nested_module() test_save_load_function() if torch.cuda.is_available(): compare_optimize_resnet18_to_torchscript() - \ No newline at end of file diff --git 
a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 1911c8388239..340f9cef9e58 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -21,6 +21,7 @@ import torch from tvm._ffi import libinfo + def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): system = platform.system() if system == "Darwin": @@ -57,4 +58,4 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): from . import optimize_torch GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper -optimize_torch = optimize_torch.optimize_torch \ No newline at end of file +optimize_torch = optimize_torch.optimize_torch diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 49e2be6b390e..c445b642ac53 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -26,13 +26,20 @@ class OperatorModuleWrapper(torch.nn.Module): - def __init__(self, module: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, tvm.contrib.graph_executor.GraphModule]): + def __init__( + self, + module: Union[ + tvm.ir.module.IRModule, + tvm.tir.function.PrimFunc, + tvm.contrib.graph_executor.GraphModule, + ], + ): super().__init__() - self.rt_module = None # runtime module - self.ir_module = module # IR moudle + self.rt_module = None # runtime module + self.ir_module = module # IR moudle - def build(self, target = None): - runtime_module = tvm.build(self.ir_module, target = target) + def build(self, target=None): + runtime_module = tvm.build(self.ir_module, target=target) func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) @@ -41,18 +48,16 @@ def build(self, target = None): def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: if self.rt_module is None: if torch_inputs[0].is_cuda: - self.build(target = "cuda") + self.build(target="cuda") elif torch_inputs[0].device.type == "cpu": self.build() else: raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet") - + return self.rt_module.forward(torch_inputs) -def as_torch( - func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] -): +def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]): """A decorator of converting TensorIR to PyTorch nn.Module. Parameters @@ -69,6 +74,8 @@ def as_torch( if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): return OperatorModuleWrapper(func) elif isinstance(func, Callable): + def func_get_param(*args, **kargs): return OperatorModuleWrapper(func(*args, **kargs)) + return func_get_param diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index ca461af3da6b..575f28a33db8 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -33,10 +33,7 @@ class GraphExecutorFactoryWrapper(torch.nn.Module): - def __init__( - self, - module: tvm.runtime.Module - ): + def __init__(self, module: tvm.runtime.Module): super().__init__() self.inner_module = module @@ -71,10 +68,10 @@ def optimize_torch( Parameters ---------- - func : callable or torch.nn.Module + func : callable or torch.nn.Module A Python function or nn.Module that could run by TorchScript's trace. 
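Note: a free function is also acceptable as `func`, as long as `torch.jit.trace` can handle it. A small sketch that relies only on the default settings (shapes and the function itself are arbitrary):

import torch
from tvm.contrib.torch import optimize_torch

def add_relu(x, y):
    return torch.nn.functional.relu(x + y)

example = (torch.randn(4, 8), torch.randn(4, 8))
# With no tuning_config/target given, the defaults are used and the module targets the CPU.
rt_mod = optimize_torch(add_relu, example)
out = rt_mod(*example)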
(ie: torch.jit.trace(model, input)) - example_inputs : tuple or torch.Tensor + example_inputs : tuple or torch.Tensor A tuple of example inputs that will run together with `func` by providing the shape information. @@ -118,8 +115,7 @@ def optimize_torch( if isinstance(example_inputs, torch.Tensor): example_inputs = [example_inputs] - shape_list = [(f"inp_{idx}", i.shape) - for idx, i in enumerate(example_inputs)] + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule if work_dir: cm = contextlib.nullcontext(work_dir) @@ -127,7 +123,8 @@ def optimize_torch( cm = tempfile.TemporaryDirectory() with cm as work_dir_path: executor_factory = tune_relay( - mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path) + mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path + ) save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") save_runtime_mod(executor_factory.module) From 79591ed2d391529181b3446a49fa99d6176620bf Mon Sep 17 00:00:00 2001 From: juda Date: Fri, 1 Jul 2022 01:46:00 -0700 Subject: [PATCH 15/27] reformat --- apps/pt_tvmdsoop/tests/test_as_torch.py | 14 ++++++------ python/tvm/contrib/torch/as_torch.py | 12 +++++++---- python/tvm/contrib/torch/optimize_torch.py | 25 ++++++++++++++-------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 0402605672a5..e2eb6f1e0291 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -17,19 +17,21 @@ # specific language governing permissions and limitations # under the License. """Test script for tvm torch module""" -import tvm -import torch -from tvm.contrib.torch import as_torch -from tvm.script import tir as T import numpy as np + +import torch import torch.nn + +import tvm import tvm.testing +from tvm.contrib.torch import as_torch +from tvm.script import tir as T @as_torch def matmul(M: int, N: int, K: int, dtype: str): @T.prim_func - def f(a: T.handle, b: T.handle, c: T.handle) -> None: + def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [M, K], dtype=dtype) B = T.match_buffer(b, [N, K], dtype=dtype) C = T.match_buffer(c, [M, N], dtype=dtype) @@ -40,7 +42,7 @@ def f(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - return f + return main @as_torch diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index c445b642ac53..c9cec6bb4dfd 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -17,14 +17,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Tuple, Union +""" +as_torch: a decorator, which is used to wrap the TVMscript code to `torch.nn.module`. +""" +from typing import Callable, List, Union import torch import torch.utils.dlpack import tvm - +# python wrapper for OperatorModule class OperatorModuleWrapper(torch.nn.Module): def __init__( self, @@ -69,9 +72,10 @@ def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Call Returns ------- mod : OperatorModuleWrapper - It will return an object of OperatorModuleWrapper, which is the subclass of the original nn.Module. 
+ It will return an object of OperatorModuleWrapper, + which is the subclass of the original nn.Module. """ - if isinstance(func, tvm.ir.module.IRModule) or isinstance(func, tvm.tir.function.PrimFunc): + if isinstance(func, tvm.ir.module.IRModule, tvm.tir.function.PrimFunc): return OperatorModuleWrapper(func) elif isinstance(func, Callable): diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 575f28a33db8..50bad2118606 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -17,11 +17,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +''' +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +''' import base64 import contextlib import tempfile -from typing import Callable, Dict, Tuple, Union, List +from typing import Tuple from tvm.meta_schedule.tune import tune_relay import torch import torch.utils.dlpack @@ -31,7 +35,7 @@ from tvm._ffi import get_global_func, register_func from tvm.meta_schedule import TuneConfig - +# The python wrapper for GraphExecutorFactory class GraphExecutorFactoryWrapper(torch.nn.Module): def __init__(self, module: tvm.runtime.Module): super().__init__() @@ -69,7 +73,8 @@ def optimize_torch( Parameters ---------- func : callable or torch.nn.Module - A Python function or nn.Module that could run by TorchScript's trace. (ie: torch.jit.trace(model, input)) + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) example_inputs : tuple or torch.Tensor A tuple of example inputs that @@ -77,7 +82,8 @@ def optimize_torch( tuning_config : tvm.meta_schedule.TuneConfig The configuration of tuning by MetaSchedule. - We suggest users to provide their own setting, otherwise by default setting a tuning process could be very slow, + We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, sometimes costs a few hours. target : Optional[Union[str, Target]] @@ -90,7 +96,8 @@ def optimize_torch( Returns ------- mod : GraphExecutorFactoryWrapper - It will return an object of GraphExecutorFactoryWrapper, which is the subclass of the original nn.Module. + It will return an object of GraphExecutorFactoryWrapper, + which is the subclass of the original nn.Module. 
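Note: because the returned wrapper is an nn.Module, it survives the usual torch.save/torch.load round trip. A brief sketch mirroring the save/load test in this series (`rt_mod` stands for the object returned by `optimize_torch`):

import tempfile
import torch

# `rt_mod` is assumed to be the wrapper returned by optimize_torch(...).
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
    torch.save(rt_mod, tmp.name)
    restored = torch.load(tmp.name)
# `restored(...)` should produce the same outputs as `rt_mod(...)`.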
""" if target: @@ -118,10 +125,10 @@ def optimize_torch( shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule if work_dir: - cm = contextlib.nullcontext(work_dir) + context_manager = contextlib.nullcontext(work_dir) else: - cm = tempfile.TemporaryDirectory() - with cm as work_dir_path: + context_manager = tempfile.TemporaryDirectory() + with context_manager as work_dir_path: executor_factory = tune_relay( mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path ) From f8a63d33bcd179cb7fc9053fc66ca25ca2ddeb70 Mon Sep 17 00:00:00 2001 From: juda Date: Fri, 1 Jul 2022 05:47:45 -0700 Subject: [PATCH 16/27] reformat --- python/tvm/contrib/torch/optimize_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 50bad2118606..8f3587d170f6 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -17,11 +17,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -''' +""" optimize_torch: aa function similar to `torch.jit.trace`, which is used to optimize the `torch.nn.module` by TVM metaSchedule, and returns a custom TorchScript operator -''' +""" import base64 import contextlib import tempfile From 4e2158eba34e263078ee12f9d8012a93ce007b18 Mon Sep 17 00:00:00 2001 From: juda Date: Fri, 1 Jul 2022 06:06:01 -0700 Subject: [PATCH 17/27] reformat --- python/tvm/contrib/torch/as_torch.py | 12 ++++++++---- python/tvm/contrib/torch/optimize_torch.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index c9cec6bb4dfd..264f3be2a5bf 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -17,7 +17,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" as_torch: a decorator, which is used to wrap the TVMscript code to `torch.nn.module`. 
""" from typing import Callable, List, Union @@ -27,6 +30,7 @@ import tvm + # python wrapper for OperatorModule class OperatorModuleWrapper(torch.nn.Module): def __init__( @@ -39,11 +43,11 @@ def __init__( ): super().__init__() self.rt_module = None # runtime module - self.ir_module = module # IR moudle + self.ir_module = module # IR modules def build(self, target=None): runtime_module = tvm.build(self.ir_module, target=target) - func = tvm.get_global_func("tvmtorch.save_runtime_mod") + func = tvm.get_global_func("PyTorch.save_runtime_mod") func(runtime_module) self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper() @@ -77,7 +81,7 @@ def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Call """ if isinstance(func, tvm.ir.module.IRModule, tvm.tir.function.PrimFunc): return OperatorModuleWrapper(func) - elif isinstance(func, Callable): + if isinstance(func, Callable): def func_get_param(*args, **kargs): return OperatorModuleWrapper(func(*args, **kargs)) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 8f3587d170f6..1ad72a657c1c 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -17,8 +17,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring """ -optimize_torch: aa function similar to `torch.jit.trace`, +optimize_torch: aa function similar to `torch.jit.trace`, which is used to optimize the `torch.nn.module` by TVM metaSchedule, and returns a custom TorchScript operator """ @@ -26,7 +29,7 @@ import contextlib import tempfile from typing import Tuple -from tvm.meta_schedule.tune import tune_relay + import torch import torch.utils.dlpack @@ -34,6 +37,8 @@ from tvm import relay from tvm._ffi import get_global_func, register_func from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + # The python wrapper for GraphExecutorFactory class GraphExecutorFactoryWrapper(torch.nn.Module): @@ -45,8 +50,7 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]): ret = self.inner_module.forward(torch_inputs) if len(ret) == 1: return ret[0] - else: - return ret + return ret def llvm_target(): From 8584be25fc403cf1b8ce8031fa200c0fc1f5f52a Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 5 Jul 2022 03:58:31 -0700 Subject: [PATCH 18/27] reformatting --- python/tvm/contrib/torch/as_torch.py | 4 ++-- src/contrib/torch/base64.h | 11 +++++----- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 21 ++++++++++--------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 264f3be2a5bf..bb7d4c2bb802 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -47,7 +47,7 @@ def __init__( def build(self, target=None): runtime_module = tvm.build(self.ir_module, target=target) - func = tvm.get_global_func("PyTorch.save_runtime_mod") + func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper() @@ -79,7 +79,7 @@ def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Call It will return an object of OperatorModuleWrapper, which is the subclass of the original nn.Module. 
""" - if isinstance(func, tvm.ir.module.IRModule, tvm.tir.function.PrimFunc): + if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): return OperatorModuleWrapper(func) if isinstance(func, Callable): diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h index d949db572c52..859fd1abcfd0 100644 --- a/src/contrib/torch/base64.h +++ b/src/contrib/torch/base64.h @@ -22,8 +22,8 @@ * \brief Util functions for converting plain bytes back to plain bytes */ -#ifndef TVM_SUPPORT_BASE64_RT_H_ -#define TVM_SUPPORT_BASE64_RT_H_ +#ifndef TVM_CONTRIB_TORCH_BASE64_H_ +#define TVM_CONTRIB_TORCH_BASE64_H_ #include @@ -36,7 +36,7 @@ namespace tvm { namespace support { -size_t b64strlen(std::string& b64str) { +size_t b64strlen(const std::string b64str) { ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; size_t length = b64str.size() / 4 * 3; if (b64str[b64str.size() - 2] == '=') { @@ -47,7 +47,7 @@ size_t b64strlen(std::string& b64str) { return length; } -void b64decode(std::string& b64str, u_char* ret) { +void b64decode(const std::string b64str, u_char* ret) { size_t index = 0; const auto length = b64str.size(); for (size_t i = 0; i < length; i += 4) { @@ -66,11 +66,10 @@ void b64decode(std::string& b64str, u_char* ret) { } } } - ret[index] = '\0'; ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; } } // namespace support } // namespace tvm -#endif // TVM_SUPPORT_BASE64_RT_H_ \ No newline at end of file +#endif // TVM_CONTRIB_TORCH_BASE64_H_ diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index fb82261de6d1..c17ad8780e6e 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -59,26 +59,25 @@ SerializationType serialize(tvm::runtime::Module module) { } struct Deleter { // deleter - Deleter(std::string file_name) { this->file_name = file_name; }; + explicit Deleter(std::string file_name) { this->file_name = file_name; } void operator()(FILE* p) const { fclose(p); ICHECK(remove(file_name.c_str()) == 0) << "remove temporary file (" << file_name << ") unsuccessfully"; - }; + } std::string file_name; }; tvm::runtime::Module deserialize(SerializationType state) { auto length = tvm::support::b64strlen(state); - u_char bytes[length]; - memset(bytes, 0, sizeof(bytes)); - tvm::support::b64decode(state, bytes); + std::vector bytes(length); + tvm::support::b64decode(state, bytes.data()); const std::string name = tmpnam(NULL); auto file_name = name + ".so"; std::unique_ptr pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); - fwrite(bytes, sizeof(u_char), length, pFile.get()); + fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); fflush(pFile.get()); std::string load_f_name = "runtime.module.loadfile_so"; @@ -128,7 +127,7 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { SerializationType Serialize() { return serialize(runtime_module); } - OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } + explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } private: tvm::runtime::Module runtime_module; @@ -165,7 +164,7 @@ tvm::Device getDevice(const at::Tensor& tensor) { */ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { public: - GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) + explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) : 
executor_factory_(executor_factory) { CHECK(executor_factory_->IsInstance()) << "module is not an instance of GraphExecutorFactory"; @@ -220,7 +219,9 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { SerializationType Serialize() { return serialize(executor_factory_); } - GraphExecutorFactoryWrapper(SerializationType state) { executor_factory_ = deserialize(state); } + explicit GraphExecutorFactoryWrapper(SerializationType state) { + executor_factory_ = deserialize(state); + } private: tvm::runtime::Module executor_factory_; @@ -258,4 +259,4 @@ TORCH_LIBRARY(tvm_tuning, m) { } } // namespace contrib -} // namespace tvm \ No newline at end of file +} // namespace tvm From 79570e85b814793c6fc8f97d593c6d04f9fc8aa2 Mon Sep 17 00:00:00 2001 From: juda Date: Mon, 11 Jul 2022 19:08:09 -0700 Subject: [PATCH 19/27] fixed --- apps/pt_tvmdsoop/tests/test_as_torch.py | 10 ++++------ apps/pt_tvmdsoop/tests/test_optimize_torch.py | 16 ++++++++-------- python/tvm/contrib/torch/as_torch.py | 1 - python/tvm/contrib/torch/optimize_torch.py | 13 ++++--------- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 4 ++-- 5 files changed, 18 insertions(+), 26 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index e2eb6f1e0291..681453216248 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -90,17 +90,15 @@ def forward(self, *input): def test_tvmscript_torch_matmul(): - s1 = np.ones((128, 128)).astype("float32") - s2 = np.ones((128, 128)).astype("float32") - s3 = np.zeros((128, 128)).astype("float32") - s1[0, 0] = 0 - s2[4, 4] = 0 + s1 = np.random.rand(128, 128).astype("float32") + s2 = np.random.rand(128, 128).astype("float32") + s3 = np.random.rand(128, 128).astype("float32") q1 = torch.from_numpy(s1) q2 = torch.from_numpy(s2) q3 = torch.from_numpy(s3) - numpy_result = np.matmul(s1, s2) + numpy_result = np.matmul(s1, np.transpose(s2)) nn_module = matmul(128, 128, 128, "float32") diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py index 02dc96701d52..258dfe55c43f 100644 --- a/apps/pt_tvmdsoop/tests/test_optimize_torch.py +++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py @@ -29,14 +29,6 @@ from tvm.contrib.torch import optimize_torch from tvm.meta_schedule import TuneConfig -# default config for testing -config = TuneConfig( - strategy="evolutionary", - num_trials_per_iter=4, - max_trials_per_task=8, - max_trials_global=16, -) - def test_matmul_tuning_relay(): def matmul(x, w): @@ -120,6 +112,14 @@ def forward(self, input): return self.resnet(input - self.means) +# default config for testing +config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=4, + max_trials_per_task=8, + max_trials_global=16, +) + if torch.cuda.is_available(): target_cuda = "nvidia/geforce-rtx-3070" meta_module_resnet18 = MyResNet18(config, target_cuda) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index bb7d4c2bb802..09bd4679e498 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -38,7 +38,6 @@ def __init__( module: Union[ tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, - tvm.contrib.graph_executor.GraphModule, ], ): super().__init__() diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 1ad72a657c1c..fa77f9d7d6b0 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ 
b/python/tvm/contrib/torch/optimize_torch.py @@ -81,8 +81,7 @@ def optimize_torch( (ie: torch.jit.trace(model, input)) example_inputs : tuple or torch.Tensor - A tuple of example inputs that - will run together with `func` by providing the shape information. + Inputs to `torch.jit.trace`. tuning_config : tvm.meta_schedule.TuneConfig The configuration of tuning by MetaSchedule. @@ -92,7 +91,7 @@ def optimize_torch( target : Optional[Union[str, Target]] The target of the compilation. - If user doesn't set the target, the module is built upon the LLVM. + If user doesn't set the target, the module will be built for the CPU target. work_dir : Optional[str] The working directory to save intermediate results. @@ -104,14 +103,10 @@ def optimize_torch( which is the subclass of the original nn.Module. """ - if target: - pass - else: + if target is None: target = llvm_target() - if tuning_config: - pass - else: + if tuning_config is None: # Default setting. For a better tuning result the number could be set large. tuning_config = TuneConfig( strategy="evolutionary", diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index c17ad8780e6e..7d7e6a463036 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -54,7 +54,7 @@ SerializationType serialize(tvm::runtime::Module module) { static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("script_torch.save_to_base64"); ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`script_torch.save_to_tar` in the global registry"; + "`script_torch.save_to_base64` in the global registry"; return (*f_to_str)(module); } @@ -63,7 +63,7 @@ struct Deleter { // deleter void operator()(FILE* p) const { fclose(p); ICHECK(remove(file_name.c_str()) == 0) - << "remove temporary file (" << file_name << ") unsuccessfully"; + << "Failed to remove temporary file (" << file_name << ")"; } std::string file_name; }; From 1218044638e967203a6cc142471db7c3430d37f6 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 13 Jul 2022 00:34:03 -0700 Subject: [PATCH 20/27] auto setting --- python/tvm/contrib/torch/optimize_torch.py | 83 +++++++++++++++++----- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index fa77f9d7d6b0..13574677e07b 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -28,7 +28,7 @@ import base64 import contextlib import tempfile -from typing import Tuple +from typing import Dict, Optional, Tuple, Union import torch import torch.utils.dlpack @@ -36,8 +36,17 @@ import tvm from tvm import relay from tvm._ffi import get_global_func, register_func -from tvm.meta_schedule import TuneConfig -from tvm.meta_schedule.tune import tune_relay +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.meta_schedule import TuneConfig, default_config +from tvm.meta_schedule.apply_history_best import ApplyHistoryBest +from tvm.meta_schedule.relay_integration import extract_task_from_relay +from tvm.meta_schedule.tune import tune_extracted_tasks +from tvm.meta_schedule.utils import autotvm_silencer +from tvm.runtime import vm +from tvm.runtime.module import Module +from tvm.runtime.ndarray import NDArray +from tvm.target.target import Target # The python wrapper for GraphExecutorFactory @@ -65,6 +74,56 @@ def save_to_base64(obj) -> bytes: 
return base64.b64encode(tfile.read()) +def tune_relay_auto( + mod: IRModule, + target: Union[str, Target], + config: TuneConfig, + work_dir: str, + backend: str = "graph", + params: Optional[Dict[str, NDArray]] = None, +) -> Union[Module, vm.Executable]: + """A wrapper of `tune_relay` but provide a default setting for the config. + + Parameters + ---------- + mod : IRModule + The module to tune. + target : Union[str, Target] + The target to tune for. + config : TuneConfig + The search strategy config. + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + work_dir : Optional[str] + The working directory to save intermediate results. + backend : str = "graph" + The backend to use for relay compilation(graph / vm). + + Returns + ------- + lib : Union[Module, tvm.runtime.vm.Executable] + The built runtime module or vm Executable for the given relay workload. + """ + target = default_config.target(target) + extracted_tasks = extract_task_from_relay(mod, target, params) + if config is None: + config = TuneConfig( + num_trials_per_iter=16, + max_trials_global=16 * len(extracted_tasks), + ) + database = tune_extracted_tasks(extracted_tasks, config, work_dir) + relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend] + with target, autotvm_silencer(), ApplyHistoryBest(database): + with PassContext( + opt_level=3, + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", + }, + ): + return relay_build(mod, target=target, params=params) + + def optimize_torch( func, example_inputs, @@ -84,10 +143,9 @@ def optimize_torch( Inputs to `torch.jit.trace`. tuning_config : tvm.meta_schedule.TuneConfig - The configuration of tuning by MetaSchedule. - We suggest users to provide their own setting, - otherwise by default setting a tuning process could be very slow, - sometimes costs a few hours. + The configuration for tuning by MetaSchedule. + If user doesn't set the config, the tuning will run with a default setting, + a number proportional to the tasks of the module. target : Optional[Union[str, Target]] The target of the compilation. @@ -106,15 +164,6 @@ def optimize_torch( if target is None: target = llvm_target() - if tuning_config is None: - # Default setting. For a better tuning result the number could be set large. 
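Note: to make the new default concrete, `tune_relay_auto` above sizes the budget from the number of extracted tasks, so a workload with, say, ten tunable tasks gets at most 16 * 10 = 160 trials in total. The task count below is hypothetical, for illustration only:

from tvm.meta_schedule import TuneConfig

# Illustrative only: how the default budget scales with the extracted tasks.
num_tasks = 10  # hypothetical count returned by extract_task_from_relay
default_config = TuneConfig(
    num_trials_per_iter=16,
    max_trials_global=16 * num_tasks,  # 160 trials for this example
)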
- tuning_config = TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=2000, - max_trials_global=2000, - ) - # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) @@ -128,7 +177,7 @@ def optimize_torch( else: context_manager = tempfile.TemporaryDirectory() with context_manager as work_dir_path: - executor_factory = tune_relay( + executor_factory = tune_relay_auto( mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path ) From 2f39a5eff497286593068a0b0a7359c263ebf069 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 13 Jul 2022 05:04:28 -0700 Subject: [PATCH 21/27] fixed --- python/tvm/contrib/torch/optimize_torch.py | 12 +++++++++--- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 3 --- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 13574677e07b..7d02f061fdd4 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -27,6 +27,7 @@ """ import base64 import contextlib +import logging import tempfile from typing import Dict, Optional, Tuple, Union @@ -144,8 +145,8 @@ def optimize_torch( tuning_config : tvm.meta_schedule.TuneConfig The configuration for tuning by MetaSchedule. - If user doesn't set the config, the tuning will run with a default setting, - a number proportional to the tasks of the module. + If user doesn't set the config, the tuning will run with a default setting. + Here, the total number of trials is proportional to the number of tunable tasks in the input module. target : Optional[Union[str, Target]] The target of the compilation. @@ -164,6 +165,11 @@ def optimize_torch( if target is None: target = llvm_target() + if tuning_config is None: + logging.warning( + "Using the default tuning parameters. The default number of trials is set to a small value to let tuning finish quickly. For optimal performance, it is recommended to provide the `tuning_config` argument with a bigger number of trials." 
+ ) + # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) @@ -184,4 +190,4 @@ def optimize_torch( save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") save_runtime_mod(executor_factory.module) - return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.GraphExecutorFactoryWrapper()) + return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper()) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 7d7e6a463036..12c1017bea76 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -243,9 +243,6 @@ TORCH_LIBRARY(tvm_torch, m) { [](SerializationType state) { return c10::make_intrusive(state); }); -} - -TORCH_LIBRARY(tvm_tuning, m) { m.class_("GraphExecutorFactoryWrapper") .def(torch::init<>()) .def("forward", &GraphExecutorFactoryWrapper::forward) From da484fe71be544a51969234d85c9b7c1aef2fbfa Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 13 Jul 2022 06:38:45 -0700 Subject: [PATCH 22/27] split long string --- python/tvm/contrib/torch/optimize_torch.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 7d02f061fdd4..5da98681af78 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -146,7 +146,8 @@ def optimize_torch( tuning_config : tvm.meta_schedule.TuneConfig The configuration for tuning by MetaSchedule. If user doesn't set the config, the tuning will run with a default setting. - Here, the total number of trials is proportional to the number of tunable tasks in the input module. + Here, the total number of trials is proportional + to the number of tunable tasks in the input module. target : Optional[Union[str, Target]] The target of the compilation. @@ -166,9 +167,13 @@ def optimize_torch( target = llvm_target() if tuning_config is None: - logging.warning( - "Using the default tuning parameters. The default number of trials is set to a small value to let tuning finish quickly. For optimal performance, it is recommended to provide the `tuning_config` argument with a bigger number of trials." + warning_msg = ( + "Using the default tuning parameters.", + "The default number of trials is set to a small value to let tuning finish quickly.", + "For optimal performance, it is recommended to provide", + "the `tuning_config` argument with a bigger number of trials.", ) + logging.warning(" ".join(warning_msg)) # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) From fb620f65a199773c5267250691043702ec202d09 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 13 Jul 2022 19:01:04 -0700 Subject: [PATCH 23/27] tune_tir --- python/tvm/contrib/torch/as_torch.py | 70 +++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 09bd4679e498..2c3f639b29fd 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -23,12 +23,19 @@ """ as_torch: a decorator, which is used to wrap the TVMscript code to `torch.nn.module`. 
""" +import tempfile from typing import Callable, List, Union import torch import torch.utils.dlpack import tvm +from tvm.meta_schedule import default_config +from tvm.meta_schedule.database.database import TuningRecord +from tvm.meta_schedule.extracted_task import ExtractedTask +from tvm.meta_schedule.tune import TuneConfig, tune_extracted_tasks +from tvm.target.target import Target +from tvm.tir.schedule.schedule import Schedule # python wrapper for OperatorModule @@ -44,8 +51,69 @@ def __init__( self.rt_module = None # runtime module self.ir_module = module # IR modules + def tune_tir_auto(self, mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc]): + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = self.tune_tir_inner( + mod=mod, + target=Target("llvm --num-cores=16"), + work_dir=work_dir, + ) + return sch.mod + + def tune_tir_inner( + self, + mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc], + target: Union[str, Target], + work_dir: str, + ): + """Tune a TIR IRModule with a given target. + + Parameters + ---------- + mod : Union[IRModule, PrimFunc] + The module to tune. + target : Union[str, Target] + The target to tune for. + work_dir : Optional[str] + The working directory to save intermediate results. + + Returns + ------- + sch : Optional[Schedule] + The tuned schedule. + """ + mod = default_config.mod(mod) + target = default_config.target(target) + + extracted_task = ExtractedTask( + task_name="main", + mod=mod, + dispatched=[mod], + target=target, + weight=1, + ) + config = TuneConfig( + # Default setting + strategy="replay_trace", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=32, + ) + database = tune_extracted_tasks( + extracted_tasks=[extracted_task], config=config, work_dir=work_dir + ) + bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1) + if not bests: + return None + assert len(bests) == 1 + sch = Schedule(mod) + bests[0].trace.apply_to_schedule(sch, remove_postproc=False) + + return sch + def build(self, target=None): - runtime_module = tvm.build(self.ir_module, target=target) + tuned_module = self.tune_tir_auto(self.ir_module) + runtime_module = tvm.build(tuned_module, target=target) func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) From 7c5f5cb2142433e53a33a2414724158c03de6192 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 14 Jul 2022 20:10:26 -0700 Subject: [PATCH 24/27] upgrade as_torch --- apps/pt_tvmdsoop/tests/test_as_torch.py | 142 ++++++++++++++++++++++++ python/tvm/contrib/torch/as_torch.py | 55 +++++---- python/tvm/script/parser.py | 16 ++- 3 files changed, 191 insertions(+), 22 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 681453216248..a904f0869690 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -23,6 +23,7 @@ import torch.nn import tvm +from tvm.meta_schedule.tune import TuneConfig import tvm.testing from tvm.contrib.torch import as_torch from tvm.script import tir as T @@ -45,6 +46,29 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: return main +@as_torch +@T.prim_func +def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j in T.grid(32, 32): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + for ii, jj in T.grid(4, 4): + C[vi * 4 + ii, vj * 4 + jj] = 
T.float32(0) + + for k in range(0, 32): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + for ii, jj, kk in T.grid(4, 4, 4): + C[vi * 4 + ii, vj * 4 + jj] = ( + C[vi * 4 + ii, vj * 4 + jj] + + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk] + ) + + @as_torch @tvm.script.ir_module class MyModule: @@ -79,6 +103,85 @@ def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: B[vi] = A[vi] + T.float32(1) +@as_torch +@T.prim_func +def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + with T.block(): + for i, j in T.grid(128, 128): + with T.block("s1"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + for i, j in T.grid(128, 128): + with T.block("s2"): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +config = TuneConfig( + strategy="replay_trace", + num_trials_per_iter=128, + max_trials_per_task=128, + max_trials_global=128, +) + + +@as_torch(config) +def softmax(M: int, N: int, dtype: str): + @T.prim_func + def f(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [M, N], dtype=dtype) + B = T.match_buffer(b, [M, N], dtype=dtype) + C = T.alloc_buffer((M), dtype=dtype, scope="local") + for i in T.thread_binding(0, M, thread="threadIdx.x"): + with T.block("row1"): + for j in T.parallel(N): + with T.block("column1"): + C[i] = T.max(C[i], A[i, j]) + for i in T.thread_binding(0, M, thread="blockIdx.x"): + with T.block("row2"): + for j in T.thread_binding(0, N, thread="threadIdx.x"): + with T.block("column2"): + B[i, j] = tvm.tir.exp(A[i, j] - C[i]) + for i in T.thread_binding(0, M, thread="blockIdx.x"): + with T.block("row3"): + C[i] = 0 + for j in T.parallel(N): + with T.block("column3"): + C[i] = C[i] + B[i, j] + for i in T.thread_binding(0, M, thread="blockIdx.x"): + with T.block("row4"): + for j in T.thread_binding(0, N, thread="threadIdx.x"): + with T.block("column4"): + B[i, j] = B[i, j] / C[i] + + return f + + +@as_torch(config) +@T.prim_func +def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + with T.block(): + for i, j in T.grid(128, 128): + with T.block("s1"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block("s2"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + class MinuesOnes(torch.nn.Module): def __init__(self): super(MinuesOnes, self).__init__() @@ -139,8 +242,47 @@ def test_torch_with_tvmscript(): tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5) +def test_tvmscript_torch_func_with_part_access_region(): + a1 = torch.rand(128, 128) + a2 = torch.zeros(128, 128) + a3 = torch.zeros(128, 128) + + result = a1 + 2 + + func_with_part_access_region(a1, a2, a3) + + tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) + + +def test_tvmscript_torch_softmax(): + x = torch.rand(300, 200).cuda() + y = torch.zeros(300, 200).cuda() + + result = torch.softmax(x, axis=1).cpu().numpy() + + func = softmax(300, 200, "float32") + func(x, y) + + tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5) + + +def test_tvmscript_torch_elementwise_with_root(): + a1 = torch.rand(128, 128) + a2 = torch.zeros(128, 
128) + a3 = torch.zeros(128, 128) + + result = a1 + 2 + + elementwise_with_root(a1, a2, a3) + + tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_tvmscript_torch_matmul() test_tvmscript_torch_decorator() test_tvmscript_torch_gpu() test_torch_with_tvmscript() + test_tvmscript_torch_func_with_part_access_region() + test_tvmscript_torch_softmax() + test_tvmscript_torch_elementwise_with_root() diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 2c3f639b29fd..dc1882168025 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -46,10 +46,21 @@ def __init__( tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, ], + config: TuneConfig = None, ): super().__init__() self.rt_module = None # runtime module self.ir_module = module # IR modules + if config is None: + self.config = TuneConfig( + # Default setting + strategy="replay_trace", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=32, + ) + else: + self.config = config def tune_tir_auto(self, mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc]): with tempfile.TemporaryDirectory() as work_dir: @@ -92,15 +103,8 @@ def tune_tir_inner( target=target, weight=1, ) - config = TuneConfig( - # Default setting - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ) database = tune_extracted_tasks( - extracted_tasks=[extracted_task], config=config, work_dir=work_dir + extracted_tasks=[extracted_task], config=self.config, work_dir=work_dir ) bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1) if not bests: @@ -131,26 +135,37 @@ def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: return self.rt_module.forward(torch_inputs) -def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]): +def as_torch(inp): """A decorator of converting TensorIR to PyTorch nn.Module. Parameters ---------- - func : Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] - The function to be parsed. - + config: Optional[TuneConfig] + A configuration for MetaSchedular tuning. Returns ------- - mod : OperatorModuleWrapper - It will return an object of OperatorModuleWrapper, + mod : Union[OperatorModuleWrapper, Callable] + It will return an object, or a templated function of OperatorModuleWrapper, which is the subclass of the original nn.Module. + """ - if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): - return OperatorModuleWrapper(func) - if isinstance(func, Callable): - def func_get_param(*args, **kargs): - return OperatorModuleWrapper(func(*args, **kargs)) + def as_torch_inner(func): + if isinstance(inp, TuneConfig): + config = inp + else: + config = None + if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): + return OperatorModuleWrapper(func, config) + if isinstance(func, Callable): + + def func_get_param(*args, **kargs): + return OperatorModuleWrapper(func(*args, **kargs), config) + + return func_get_param + raise Exception("Incorrect `as_torch` formatting.") - return func_get_param + if isinstance(inp, TuneConfig): + return as_torch_inner + return as_torch_inner(inp) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index e4bdd1206506..9edf08a5ba72 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -435,11 +435,23 @@ def my_function(x: T.handle): # 1. 
Argument types T.evaluate(0) # 4. This function returns 0 """ + def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]): + if isinstance(decorator, ast.Call): + if len(decorator.params) != 1: + return False + func_name = decorator.func_name + else: + func_name = decorator + if isinstance(func_name, ast.Var): + return func_name.id.name == "as_torch" + def check_decorator(decorators: List[ast.Expr]) -> bool: """Check the decorator is `T.prim_func""" - if len(decorators) != 1: + if len(decorators) > 2 or len(decorators) == 0: + return False + if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]): return False - d: ast.Expr = decorators[0] + d: ast.Expr = decorators[-1] return ( isinstance(d, ast.Attr) and isinstance(d.object, ast.Var) From 14a6cd1f5efc5cfb1e0c8fda6444802faf006347 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 14 Jul 2022 22:15:36 -0700 Subject: [PATCH 25/27] optimize as_torch --- apps/pt_tvmdsoop/tests/test_as_torch.py | 145 ++++++++------------- python/tvm/contrib/torch/as_torch.py | 81 +++--------- python/tvm/contrib/torch/optimize_torch.py | 4 +- 3 files changed, 73 insertions(+), 157 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index a904f0869690..86bd461bd2c6 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -46,47 +46,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: return main -@as_torch -@T.prim_func -def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - - for i, j in T.grid(32, 32): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i, j]) - for ii, jj in T.grid(4, 4): - C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) - - for k in range(0, 32): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - for ii, jj, kk in T.grid(4, 4, 4): - C[vi * 4 + ii, vj * 4 + jj] = ( - C[vi * 4 + ii, vj * 4 + jj] - + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk] - ) - - -@as_torch -@tvm.script.ir_module -class MyModule: - @T.prim_func - def main(a: T.handle, b: T.handle): - # We exchange data between function by handles, which are similar to pointer. - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # Create buffer from handles. - A = T.match_buffer(a, (8,), dtype="float32") - B = T.match_buffer(b, (8,), dtype="float32") - for i in range(8): - # A block is an abstraction for computation. - with T.block("B"): - # Define a spatial block iterator and bind it to value i. 
- vi = T.axis.spatial(8, i) - B[vi] = A[vi] + 1.0 - - @as_torch @tvm.script.ir_module class ModuleGPU: @@ -133,53 +92,59 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: @as_torch(config) -def softmax(M: int, N: int, dtype: str): +@tvm.script.ir_module +class MyModule: @T.prim_func - def f(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [M, N], dtype=dtype) - B = T.match_buffer(b, [M, N], dtype=dtype) - C = T.alloc_buffer((M), dtype=dtype, scope="local") - for i in T.thread_binding(0, M, thread="threadIdx.x"): - with T.block("row1"): - for j in T.parallel(N): - with T.block("column1"): - C[i] = T.max(C[i], A[i, j]) - for i in T.thread_binding(0, M, thread="blockIdx.x"): - with T.block("row2"): - for j in T.thread_binding(0, N, thread="threadIdx.x"): - with T.block("column2"): - B[i, j] = tvm.tir.exp(A[i, j] - C[i]) - for i in T.thread_binding(0, M, thread="blockIdx.x"): - with T.block("row3"): - C[i] = 0 - for j in T.parallel(N): - with T.block("column3"): - C[i] = C[i] + B[i, j] - for i in T.thread_binding(0, M, thread="blockIdx.x"): - with T.block("row4"): - for j in T.thread_binding(0, N, thread="threadIdx.x"): - with T.block("column4"): - B[i, j] = B[i, j] / C[i] - - return f + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. + vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 @as_torch(config) @T.prim_func -def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) +def loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] - with T.block(): - for i, j in T.grid(128, 128): - with T.block("s1"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.float32(1) - for i, j in T.grid(128, 128): - with T.block("s2"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, vj] + T.float32(1) + +@as_torch(config) +def elementwise_with_root(M: int, N: int, dtype: str): + @T.prim_func + def f(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, N]) + B = T.match_buffer(b, [M, N]) + C = T.match_buffer(c, [M, N]) + + with T.block(): + for i, j in T.grid(M, N): + with T.block("s1"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + for i, j in T.grid(M, N): + with T.block("s2"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + return f class MinuesOnes(torch.nn.Module): @@ -254,14 +219,13 @@ def test_tvmscript_torch_func_with_part_access_region(): tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) -def test_tvmscript_torch_softmax(): - x = torch.rand(300, 200).cuda() - y = torch.zeros(300, 200).cuda() +def test_tvmscript_torch_loop_split(): + 
x = torch.rand(128, 128).cuda() + y = torch.zeros(128).cuda() - result = torch.softmax(x, axis=1).cpu().numpy() + result = torch.sum(x.cpu(), dim=1).numpy() - func = softmax(300, 200, "float32") - func(x, y) + loop_split(x, y) tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5) @@ -273,7 +237,8 @@ def test_tvmscript_torch_elementwise_with_root(): result = a1 + 2 - elementwise_with_root(a1, a2, a3) + func = elementwise_with_root(128, 128, "float32") + func(a1, a2, a3) tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) @@ -284,5 +249,5 @@ def test_tvmscript_torch_elementwise_with_root(): test_tvmscript_torch_gpu() test_torch_with_tvmscript() test_tvmscript_torch_func_with_part_access_region() - test_tvmscript_torch_softmax() + test_tvmscript_torch_loop_split() test_tvmscript_torch_elementwise_with_root() diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index dc1882168025..44e8ab36e8b9 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -30,10 +30,7 @@ import torch.utils.dlpack import tvm -from tvm.meta_schedule import default_config -from tvm.meta_schedule.database.database import TuningRecord -from tvm.meta_schedule.extracted_task import ExtractedTask -from tvm.meta_schedule.tune import TuneConfig, tune_extracted_tasks +from tvm.meta_schedule.tune import TuneConfig, tune_tir from tvm.target.target import Target from tvm.tir.schedule.schedule import Schedule @@ -51,73 +48,26 @@ def __init__( super().__init__() self.rt_module = None # runtime module self.ir_module = module # IR modules - if config is None: - self.config = TuneConfig( - # Default setting - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ) - else: - self.config = config + self.config = config - def tune_tir_auto(self, mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc]): + def tune_tir_auto(self, mod, target): + if target is None: + target = Target("llvm --num-cores=16") with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule = self.tune_tir_inner( + sch: Schedule = tune_tir( mod=mod, - target=Target("llvm --num-cores=16"), + target=target, + config=self.config, work_dir=work_dir, ) return sch.mod - def tune_tir_inner( - self, - mod: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc], - target: Union[str, Target], - work_dir: str, - ): - """Tune a TIR IRModule with a given target. - - Parameters - ---------- - mod : Union[IRModule, PrimFunc] - The module to tune. - target : Union[str, Target] - The target to tune for. - work_dir : Optional[str] - The working directory to save intermediate results. - - Returns - ------- - sch : Optional[Schedule] - The tuned schedule. 
- """ - mod = default_config.mod(mod) - target = default_config.target(target) - - extracted_task = ExtractedTask( - task_name="main", - mod=mod, - dispatched=[mod], - target=target, - weight=1, - ) - database = tune_extracted_tasks( - extracted_tasks=[extracted_task], config=self.config, work_dir=work_dir - ) - bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1) - if not bests: - return None - assert len(bests) == 1 - sch = Schedule(mod) - bests[0].trace.apply_to_schedule(sch, remove_postproc=False) - - return sch - def build(self, target=None): - tuned_module = self.tune_tir_auto(self.ir_module) - runtime_module = tvm.build(tuned_module, target=target) + if self.config is not None: + module = self.tune_tir_auto(self.ir_module, target) + else: + module = self.ir_module + runtime_module = tvm.build(module, target=target) func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) @@ -126,7 +76,7 @@ def build(self, target=None): def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: if self.rt_module is None: if torch_inputs[0].is_cuda: - self.build(target="cuda") + self.build(target=Target("nvidia/geforce-rtx-3070")) elif torch_inputs[0].device.type == "cpu": self.build() else: @@ -141,7 +91,8 @@ def as_torch(inp): Parameters ---------- config: Optional[TuneConfig] - A configuration for MetaSchedular tuning. + The configuration for tuning by MetaSchedule. + If user doesn't set the config, the tuning will run with a default setting. Returns ------- diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 5da98681af78..23636e36b714 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -27,9 +27,9 @@ """ import base64 import contextlib -import logging import tempfile from typing import Dict, Optional, Tuple, Union +import warnings import torch import torch.utils.dlpack @@ -173,7 +173,7 @@ def optimize_torch( "For optimal performance, it is recommended to provide", "the `tuning_config` argument with a bigger number of trials.", ) - logging.warning(" ".join(warning_msg)) + warnings.warn(" ".join(warning_msg)) # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) From 150ca1f9547a48191ac707b9d6ee6034f90981c4 Mon Sep 17 00:00:00 2001 From: juda Date: Fri, 15 Jul 2022 05:32:08 -0700 Subject: [PATCH 26/27] as_torch --- apps/pt_tvmdsoop/tests/test_as_torch.py | 10 +++- python/tvm/contrib/torch/as_torch.py | 68 +++++++++++----------- python/tvm/contrib/torch/optimize_torch.py | 2 +- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 86bd461bd2c6..2c454e9454e7 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -24,6 +24,7 @@ import tvm from tvm.meta_schedule.tune import TuneConfig +from tvm.target.target import Target import tvm.testing from tvm.contrib.torch import as_torch from tvm.script import tir as T @@ -91,7 +92,7 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: ) -@as_torch(config) +@as_torch @tvm.script.ir_module class MyModule: @T.prim_func @@ -109,7 +110,7 @@ def main(a: T.handle, b: T.handle): B[vi] = A[vi] + 1.0 -@as_torch(config) +@as_torch @T.prim_func def loop_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") @@ -126,7 
+127,7 @@ def loop_split(a: T.handle, b: T.handle) -> None: B[vi] = B[vi] + A[vi, vk] -@as_torch(config) +@as_torch def elementwise_with_root(M: int, N: int, dtype: str): @T.prim_func def f(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -214,6 +215,7 @@ def test_tvmscript_torch_func_with_part_access_region(): result = a1 + 2 + func_with_part_access_region.tune() func_with_part_access_region(a1, a2, a3) tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) @@ -225,6 +227,7 @@ def test_tvmscript_torch_loop_split(): result = torch.sum(x.cpu(), dim=1).numpy() + loop_split.tune(config, Target("nvidia/geforce-rtx-3070")) loop_split(x, y) tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5) @@ -238,6 +241,7 @@ def test_tvmscript_torch_elementwise_with_root(): result = a1 + 2 func = elementwise_with_root(128, 128, "float32") + func.tune(config) func(a1, a2, a3) tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5) diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index 44e8ab36e8b9..3a2b4dda9ea9 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -43,31 +43,45 @@ def __init__( tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, ], - config: TuneConfig = None, ): super().__init__() self.rt_module = None # runtime module self.ir_module = module # IR modules - self.config = config - def tune_tir_auto(self, mod, target): + def tune(self, config: TuneConfig = None, target: Union[str, Target] = None): + """ + Tune the TVMscript code. + + Parameters + ---------- + config: Optional[TuneConfig] + The tuning configuration. + + target : Optional[str, Target] + The target to tune for. + """ + if config is None: + config = TuneConfig( + # Default setting + strategy="replay_trace", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=32, + ) if target is None: target = Target("llvm --num-cores=16") with tempfile.TemporaryDirectory() as work_dir: sch: Schedule = tune_tir( - mod=mod, + mod=self.ir_module, target=target, - config=self.config, + config=config, work_dir=work_dir, ) - return sch.mod + self.ir_module = sch.mod + self.build(target) def build(self, target=None): - if self.config is not None: - module = self.tune_tir_auto(self.ir_module, target) - else: - module = self.ir_module - runtime_module = tvm.build(module, target=target) + runtime_module = tvm.build(self.ir_module, target=target) func = tvm.get_global_func("tvmtorch.save_runtime_mod") func(runtime_module) @@ -76,7 +90,7 @@ def build(self, target=None): def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: if self.rt_module is None: if torch_inputs[0].is_cuda: - self.build(target=Target("nvidia/geforce-rtx-3070")) + self.build(target="cuda") elif torch_inputs[0].device.type == "cpu": self.build() else: @@ -85,14 +99,13 @@ def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]: return self.rt_module.forward(torch_inputs) -def as_torch(inp): +def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]): """A decorator of converting TensorIR to PyTorch nn.Module. Parameters ---------- - config: Optional[TuneConfig] - The configuration for tuning by MetaSchedule. - If user doesn't set the config, the tuning will run with a default setting. + func: Optional[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable] + The function written by TVMscript. 
Returns ------- @@ -101,22 +114,11 @@ def as_torch(inp): which is the subclass of the original nn.Module. """ + if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): + return OperatorModuleWrapper(func) + if isinstance(func, Callable): - def as_torch_inner(func): - if isinstance(inp, TuneConfig): - config = inp - else: - config = None - if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): - return OperatorModuleWrapper(func, config) - if isinstance(func, Callable): - - def func_get_param(*args, **kargs): - return OperatorModuleWrapper(func(*args, **kargs), config) - - return func_get_param - raise Exception("Incorrect `as_torch` formatting.") + def func_get_param(*args, **kargs): + return OperatorModuleWrapper(func(*args, **kargs)) - if isinstance(inp, TuneConfig): - return as_torch_inner - return as_torch_inner(inp) + return func_get_param diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 23636e36b714..efc3b699a52a 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -173,7 +173,7 @@ def optimize_torch( "For optimal performance, it is recommended to provide", "the `tuning_config` argument with a bigger number of trials.", ) - warnings.warn(" ".join(warning_msg)) + warnings.warn(" ".join(warning_msg), stacklevel=2) # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) From 7093b1335270c6234a98727d946bf082619b538d Mon Sep 17 00:00:00 2001 From: juda Date: Mon, 25 Jul 2022 18:55:33 -0700 Subject: [PATCH 27/27] fixed typo --- python/tvm/contrib/torch/optimize_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index efc3b699a52a..282e6c5dc84f 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -21,7 +21,7 @@ # pylint: disable=missing-class-docstring # pylint: disable=missing-function-docstring """ -optimize_torch: aa function similar to `torch.jit.trace`, +optimize_torch: a function similar to `torch.jit.trace`, which is used to optimize the `torch.nn.module` by TVM metaSchedule, and returns a custom TorchScript operator """
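
For reference, a minimal usage sketch of `optimize_torch` as described in the docstring above: it traces a callable against example inputs, tunes it with MetaSchedule, and returns a module that is called in place of the original. The toy `add_relu` function, the 4x4 shapes, and the tiny trial budget are illustrative assumptions only, and the `tuning_config` keyword is inferred from the warning message in `optimize_torch.py`; leaving it unset falls back to a default config and emits that warning.

import torch
from tvm.contrib.torch import optimize_torch
from tvm.meta_schedule import TuneConfig


def add_relu(x, y):
    # Toy workload, for illustration only.
    return torch.nn.functional.relu(x + y)


example_inputs = (torch.randn(4, 4), torch.randn(4, 4))

# Keep the trial budget tiny so the sketch finishes quickly; real tuning
# should use far more trials, as the warning in optimize_torch.py suggests.
config = TuneConfig(
    strategy="replay_trace",
    num_trials_per_iter=4,
    max_trials_per_task=4,
    max_trials_global=4,
)

rt_mod = optimize_torch(add_relu, example_inputs, tuning_config=config)
out = rt_mod(*example_inputs)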
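
The `as_torch` path reworked in the preceding commits can be exercised the same way. A minimal sketch, mirroring the add-one kernel used by the series' own tests: wrap a TVMScript `PrimFunc`, optionally call `.tune()` (MetaSchedule with the default replay_trace config and an LLVM target when no arguments are given), then invoke the wrapper like any other `torch.nn.Module`.

import torch
import tvm.testing
from tvm.contrib.torch import as_torch
from tvm.script import tir as T


@as_torch
@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    # Same shape and computation as the 8-element kernel in test_as_torch.py.
    A = T.match_buffer(a, (8,), dtype="float32")
    B = T.match_buffer(b, (8,), dtype="float32")
    for i in range(8):
        with T.block("B"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi] + 1.0


x = torch.arange(8, dtype=torch.float32)
y = torch.zeros(8, dtype=torch.float32)

# Optional: tune first; without arguments this uses the default replay_trace
# config and an LLVM target, then rebuilds the runtime module.
add_one.tune()

# The wrapper builds lazily on the first call (here it is already built by
# tune) and writes the result into the output tensor in place.
add_one(x, y)
tvm.testing.assert_allclose(y.numpy(), (x + 1.0).numpy(), atol=1e-5, rtol=1e-5)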