diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py
new file mode 100644
index 000000000000..2c454e9454e7
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import numpy as np
+
+import torch
+import torch.nn
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig
+from tvm.target.target import Target
+import tvm.testing
+from tvm.contrib.torch import as_torch
+from tvm.script import tir as T
+
+
+@as_torch
+def matmul(M: int, N: int, K: int, dtype: str):
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [M, K], dtype=dtype)
+        B = T.match_buffer(b, [N, K], dtype=dtype)
+        C = T.match_buffer(c, [M, N], dtype=dtype)
+        for i, j, k in T.grid(M, N, K):
+            with T.block():
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    return main
+
+
+@as_torch
+@tvm.script.ir_module
+class ModuleGPU:
+    @T.prim_func
+    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.thread_binding(2, thread="blockIdx.x"):
+            for i_2 in T.thread_binding(2, thread="threadIdx.x"):
+                for i_1 in T.serial(2):
+                    with T.block("B"):
+                        vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
+                        T.reads(A[vi])
+                        T.writes(B[vi])
+                        B[vi] = A[vi] + T.float32(1)
+
+
+@as_torch
+@T.prim_func
+def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+
+    with T.block():
+        for i, j in T.grid(128, 128):
+            with T.block("s1"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])
+                B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        for i, j in T.grid(128, 128):
+            with T.block("s2"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+config = TuneConfig(
+    strategy="replay_trace",
+    num_trials_per_iter=128,
+    max_trials_per_task=128,
+    max_trials_global=128,
+)
+
+
+@as_torch
+@tvm.script.ir_module
+class MyModule:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        # We exchange data between functions by handles, which are similar to pointers.
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Create buffers from the handles.
+        A = T.match_buffer(a, (8,), dtype="float32")
+        B = T.match_buffer(b, (8,), dtype="float32")
+        for i in range(8):
+            # A block is an abstraction for computation.
+            with T.block("B"):
+                # Define a spatial block iterator and bind it to value i.
+                vi = T.axis.spatial(8, i)
+                B[vi] = A[vi] + 1.0
+
+
+@as_torch
+@T.prim_func
+def loop_split(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i, ko in T.grid(128, 4):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B"):
+                vi = T.axis.S(128, i)
+                vk = T.axis.R(128, ko * 32 + ki)
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@as_torch
+def elementwise_with_root(M: int, N: int, dtype: str):
+    @T.prim_func
+    def f(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [M, N])
+        B = T.match_buffer(b, [M, N])
+        C = T.match_buffer(c, [M, N])
+
+        with T.block():
+            for i, j in T.grid(M, N):
+                with T.block("s1"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+            for i, j in T.grid(M, N):
+                with T.block("s2"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = B[vi, vj] + T.float32(1)
+
+    return f
+
+
+class MinusOnes(torch.nn.Module):
+    def __init__(self):
+        super(MinusOnes, self).__init__()
+        self.engine = MyModule
+
+    def forward(self, *input):
+        self.engine.forward(*input)
+        return input[-1] - 1
+
+
+def test_tvmscript_torch_matmul():
+    s1 = np.random.rand(128, 128).astype("float32")
+    s2 = np.random.rand(128, 128).astype("float32")
+    s3 = np.random.rand(128, 128).astype("float32")
+
+    q1 = torch.from_numpy(s1)
+    q2 = torch.from_numpy(s2)
+    q3 = torch.from_numpy(s3)
+
+    numpy_result = np.matmul(s1, np.transpose(s2))
+
+    nn_module = matmul(128, 128, 128, "float32")
+
+    nn_module(q1, q2, q3)
+
+    tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_decorator():
+    q1 = torch.arange(8).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32)
+
+    MyModule(q1, q2)
+
+    tvm.testing.assert_allclose(q2.numpy(), (q1 + 1).numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_gpu():
+    cuda0 = torch.device("cuda:0")
+    q1 = torch.arange(8, device=cuda0).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
+
+    ModuleGPU(q1, q2)
+
+    tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_torch_with_tvmscript():
+    ref_result = np.arange(8).astype("float32")
+
+    q1 = torch.arange(8).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32)
+
+    nn_module = MinusOnes()
+
+    ret = nn_module.forward(q1, q2)
+
+    tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_func_with_part_access_region():
+    a1 = torch.rand(128, 128)
+    a2 = torch.zeros(128, 128)
+    a3 = torch.zeros(128, 128)
+
+    result = a1 + 2
+
+    func_with_part_access_region.tune()
+    func_with_part_access_region(a1, a2, a3)
+
+    tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_loop_split():
+    x = torch.rand(128, 128).cuda()
+    y = torch.zeros(128).cuda()
+
+    result = torch.sum(x.cpu(), dim=1).numpy()
+
+    loop_split.tune(config, Target("nvidia/geforce-rtx-3070"))
+    loop_split(x, y)
+
+    tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_elementwise_with_root():
+    a1 = torch.rand(128, 128)
+    a2 = torch.zeros(128, 128)
+    a3 = torch.zeros(128, 128)
+
+    result = a1 + 2
+
+    func = elementwise_with_root(128, 128, "float32")
+    func.tune(config)
+    func(a1, a2, a3)
+
+    tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    test_tvmscript_torch_matmul()
+    test_tvmscript_torch_decorator()
+    test_tvmscript_torch_gpu()
+    test_torch_with_tvmscript()
+    test_tvmscript_torch_func_with_part_access_region()
+    test_tvmscript_torch_loop_split()
+    test_tvmscript_torch_elementwise_with_root()
diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
new file mode 100644
index 000000000000..258dfe55c43f
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
@@ -0,0 +1,161 @@
+# pylint: disable=missing-class-docstring
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import tempfile
+
+import torch
+from torch.utils import benchmark
+from torchvision.models import resnet18
+
+import tvm
+import tvm.testing
+from tvm.contrib.torch import optimize_torch
+from tvm.meta_schedule import TuneConfig
+
+
+def test_matmul_tuning_relay():
+    def matmul(x, w):
+        return torch.matmul(x, w)
+
+    x = torch.randn(15, 20)
+    w = torch.randn(20, 30)
+    example_inputs = (x, w)
+
+    rt_mod = optimize_torch(matmul, example_inputs)
+    torch_answer = torch.matmul(x, w).numpy()
+    tvm_answer = rt_mod(x, w).numpy()
+
+    tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5)
+
+
+class InnerModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(1, 20, 5)
+
+    def forward(self, x):
+        return torch.nn.functional.relu(self.conv(x))
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(20, 20, 5)
+        self.relu = InnerModel()
+
+    def forward(self, x):
+        x = self.relu(x)
+        return torch.nn.functional.relu(self.conv(x))
+
+
+def test_nested_module():
+    simple_module = SimpleModel()
+    example_input = torch.randn(20, 1, 10, 10)
+    optimized_module = optimize_torch(simple_module, example_input)
+    ret1 = simple_module(example_input).detach().numpy()
+    ret2 = optimized_module(example_input).detach().numpy()
+    tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+def test_save_load_function():
+    def foo(x):
+        return 2 * x + 1
+
+    example_input = torch.rand(3)
+    opt_foo = optimize_torch(foo, example_input)
+    ret1 = opt_foo(example_input)
+    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+        torch.save(opt_foo, tmp.name)
+        loaded_mod = torch.load(tmp.name)
+        ret2 = loaded_mod(example_input)
+    tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5)
+
+
+class MyResNet18(torch.nn.Module):
+    def __init__(self, config, target=None):
+        super(MyResNet18, self).__init__()
+        self.means = torch.nn.Parameter(
+            torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+        ).cuda()
+        self.resnet = optimize_torch(resnet18(), [torch.rand(1, 3, 224, 224)], config, target)
+
+    def forward(self, input):
+        return self.resnet(input - self.means)
+
+
+class JitModule(torch.nn.Module):
+    def __init__(self):
+        super(JitModule, self).__init__()
+        self.means = torch.nn.Parameter(
+            torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+        ).cuda()
+        self.resnet = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
+
+    def forward(self, input):
+        return self.resnet(input - self.means)
+
+
+# default config for testing
+config = TuneConfig(
+    strategy="evolutionary",
+    num_trials_per_iter=4,
+    max_trials_per_task=8,
+    max_trials_global=16,
+)
+
+if torch.cuda.is_available():
+    target_cuda = "nvidia/geforce-rtx-3070"
+    meta_module_resnet18 = MyResNet18(config, target_cuda)
+    jit_module_resnet18 = JitModule()
+
+
+def compare_optimize_resnet18_to_torchscript():
+    results = []
+    for i in range(20):
+        test_input = torch.rand(1, 3, 224, 224).half().cuda()
+        sub_label = f"[test {i}]"
+        results.append(
+            benchmark.Timer(
+                stmt="meta_module_resnet18(test_input)",
+                setup="from __main__ import meta_module_resnet18",
+                globals={"test_input": test_input},
+                sub_label=sub_label,
+                description="tuning by meta",
+            ).blocked_autorange()
+        )
+        results.append(
+            benchmark.Timer(
+                stmt="jit_module_resnet18(test_input)",
+                setup="from __main__ import jit_module_resnet18",
+                globals={"test_input": test_input},
+                sub_label=sub_label,
+                description="tuning by jit",
+            ).blocked_autorange()
+        )
+    compare = benchmark.Compare(results)
+    compare.print()
+
+
+if __name__ == "__main__":
+    test_matmul_tuning_relay()
+    test_nested_module()
+    test_save_load_function()
+    if torch.cuda.is_available():
+        compare_optimize_resnet18_to_torchscript()
diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py
index 720ac29cc6e2..340f9cef9e58 100644
--- a/python/tvm/contrib/torch/__init__.py
+++ b/python/tvm/contrib/torch/__init__.py
@@ -20,7 +20,6 @@
 import platform
 import torch
 from tvm._ffi import libinfo
-from tvm.relay.frontend import pytorch
 
 
 def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
@@ -39,6 +38,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
 
 _load_platform_specific_library()
 
+
 from . import module
 
 GraphModule = module.GraphModule
@@ -49,3 +49,13 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
 PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule
 compile = pytorch_tvm.compile
+
+from . import as_torch
+
+TVMScriptIRModule = as_torch.OperatorModuleWrapper
+as_torch = as_torch.as_torch
+
+from . import optimize_torch
+
+GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper
+optimize_torch = optimize_torch.optimize_torch
diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py
new file mode 100644
index 000000000000..3a2b4dda9ea9
--- /dev/null
+++ b/python/tvm/contrib/torch/as_torch.py
@@ -0,0 +1,124 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+as_torch: a decorator that wraps TVMScript code into a `torch.nn.Module`.
+"""
+import tempfile
+from typing import Callable, List, Union
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig, tune_tir
+from tvm.target.target import Target
+from tvm.tir.schedule.schedule import Schedule
+
+
+# Python wrapper for OperatorModule
+class OperatorModuleWrapper(torch.nn.Module):
+    def __init__(
+        self,
+        module: Union[
+            tvm.ir.module.IRModule,
+            tvm.tir.function.PrimFunc,
+        ],
+    ):
+        super().__init__()
+        self.rt_module = None  # runtime module
+        self.ir_module = module  # IR module
+
+    def tune(self, config: TuneConfig = None, target: Union[str, Target] = None):
+        """
+        Tune the TVMScript code.
+
+        Parameters
+        ----------
+        config : Optional[TuneConfig]
+            The tuning configuration.
+
+        target : Optional[Union[str, Target]]
+            The target to tune for.
+        """
+        if config is None:
+            config = TuneConfig(
+                # Default setting
+                strategy="replay_trace",
+                num_trials_per_iter=32,
+                max_trials_per_task=32,
+                max_trials_global=32,
+            )
+        if target is None:
+            target = Target("llvm --num-cores=16")
+        with tempfile.TemporaryDirectory() as work_dir:
+            sch: Schedule = tune_tir(
+                mod=self.ir_module,
+                target=target,
+                config=config,
+                work_dir=work_dir,
+            )
+            self.ir_module = sch.mod
+            self.build(target)
+
+    def build(self, target=None):
+        runtime_module = tvm.build(self.ir_module, target=target)
+        func = tvm.get_global_func("tvmtorch.save_runtime_mod")
+        func(runtime_module)
+
+        self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()
+
+    def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        if self.rt_module is None:
+            if torch_inputs[0].is_cuda:
+                self.build(target="cuda")
+            elif torch_inputs[0].device.type == "cpu":
+                self.build()
+            else:
+                raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet")
+
+        return self.rt_module.forward(torch_inputs)
+
+
+def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]):
+    """A decorator converting TensorIR into a PyTorch nn.Module.
+
+    Parameters
+    ----------
+    func : Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]
+        The function written in TVMScript.
+
+    Returns
+    -------
+    mod : Union[OperatorModuleWrapper, Callable]
+        An OperatorModuleWrapper instance (a subclass of torch.nn.Module)
+        wrapping the input, or, for a templated function, a callable that
+        produces such a wrapper.
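+
+    Examples
+    --------
+    A minimal sketch of the decorator form, mirroring the
+    `test_tvmscript_torch_decorator` test in this PR (calling `.tune()` first
+    is optional; `forward` builds for the input's device on first use):
+
+    .. code-block:: python
+
+        @as_torch
+        @tvm.script.ir_module
+        class MyModule:
+            @T.prim_func
+            def main(a: T.handle, b: T.handle):
+                T.func_attr({"global_symbol": "main", "tir.noalias": True})
+                A = T.match_buffer(a, (8,), dtype="float32")
+                B = T.match_buffer(b, (8,), dtype="float32")
+                for i in range(8):
+                    with T.block("B"):
+                        vi = T.axis.spatial(8, i)
+                        B[vi] = A[vi] + 1.0
+
+        q1 = torch.arange(8).type(torch.float32)
+        q2 = torch.zeros((8,), dtype=torch.float32)
+        MyModule(q1, q2)  # after the call, q2 == q1 + 1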
+
+    """
+    if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)):
+        return OperatorModuleWrapper(func)
+    if isinstance(func, Callable):
+
+        def func_get_param(*args, **kargs):
+            return OperatorModuleWrapper(func(*args, **kargs))
+
+        return func_get_param
diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py
new file mode 100644
index 000000000000..282e6c5dc84f
--- /dev/null
+++ b/python/tvm/contrib/torch/optimize_torch.py
@@ -0,0 +1,198 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+optimize_torch: a function similar to `torch.jit.trace`,
+which optimizes a `torch.nn.Module` via TVM MetaSchedule
+and returns a custom TorchScript operator.
+"""
+import base64
+import contextlib
+import tempfile
+from typing import Dict, Optional, Tuple, Union
+import warnings
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm import relay
+from tvm._ffi import get_global_func, register_func
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.meta_schedule import TuneConfig, default_config
+from tvm.meta_schedule.apply_history_best import ApplyHistoryBest
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
+from tvm.meta_schedule.tune import tune_extracted_tasks
+from tvm.meta_schedule.utils import autotvm_silencer
+from tvm.runtime import vm
+from tvm.runtime.module import Module
+from tvm.runtime.ndarray import NDArray
+from tvm.target.target import Target
+
+
+# The Python wrapper for GraphExecutorFactory
+class GraphExecutorFactoryWrapper(torch.nn.Module):
+    def __init__(self, module: tvm.runtime.Module):
+        super().__init__()
+        self.inner_module = module
+
+    def forward(self, *torch_inputs: Tuple[torch.Tensor]):
+        ret = self.inner_module.forward(torch_inputs)
+        if len(ret) == 1:
+            return ret[0]
+        return ret
+
+
+def llvm_target():
+    return "llvm -num-cores"
+
+
+@register_func("script_torch.save_to_base64")
+def save_to_base64(obj) -> bytes:
+    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
+        obj.export_library(tmpfile.name)
+        with open(tmpfile.name, "rb") as tfile:
+            return base64.b64encode(tfile.read())
+
+
+def tune_relay_auto(
+    mod: IRModule,
+    target: Union[str, Target],
+    config: TuneConfig,
+    work_dir: str,
+    backend: str = "graph",
+    params: Optional[Dict[str, NDArray]] = None,
+) -> Union[Module, vm.Executable]:
+    """A wrapper of `tune_relay` that provides a default setting for the config.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The module to tune.
+    target : Union[str, Target]
+        The target to tune for.
+    config : TuneConfig
+        The search strategy config.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the program.
+    work_dir : Optional[str]
+        The working directory to save intermediate results.
+    backend : str = "graph"
+        The backend to use for relay compilation (graph / vm).
+
+    Returns
+    -------
+    lib : Union[Module, tvm.runtime.vm.Executable]
+        The built runtime module or vm Executable for the given relay workload.
+    """
+    target = default_config.target(target)
+    extracted_tasks = extract_task_from_relay(mod, target, params)
+    if config is None:
+        config = TuneConfig(
+            num_trials_per_iter=16,
+            max_trials_global=16 * len(extracted_tasks),
+        )
+    database = tune_extracted_tasks(extracted_tasks, config, work_dir)
+    relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
+    with target, autotvm_silencer(), ApplyHistoryBest(database):
+        with PassContext(
+            opt_level=3,
+            config={
+                "relay.backend.use_meta_schedule": True,
+                "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+            },
+        ):
+            return relay_build(mod, target=target, params=params)
+
+
+def optimize_torch(
+    func,
+    example_inputs,
+    tuning_config=None,
+    target=None,
+    work_dir=None,
+):
+    """Load a PyTorch model that can be traced by TorchScript, then optimize it via MetaSchedule.
+
+    Parameters
+    ----------
+    func : callable or torch.nn.Module
+        A Python function or nn.Module that can be traced by TorchScript
+        (i.e., torch.jit.trace(model, input)).
+
+    example_inputs : tuple or torch.Tensor
+        Inputs to `torch.jit.trace`.
+
+    tuning_config : tvm.meta_schedule.TuneConfig
+        The configuration for tuning by MetaSchedule.
+        If the user doesn't set the config, the tuning will run with a default setting.
+        Here, the total number of trials is proportional
+        to the number of tunable tasks in the input module.
+
+    target : Optional[Union[str, Target]]
+        The target of the compilation.
+        If the user doesn't set the target, the module will be built for the CPU target.
+
+    work_dir : Optional[str]
+        The working directory to save intermediate results.
+
+    Returns
+    -------
+    mod : GraphExecutorFactoryWrapper
+        A GraphExecutorFactoryWrapper object, which is a subclass of
+        torch.nn.Module, wrapping the optimized model.
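+
+    Examples
+    --------
+    A minimal sketch mirroring `test_matmul_tuning_relay` in this PR's test
+    file (with `tuning_config=None`, a small default tuning run is used):
+
+    .. code-block:: python
+
+        def matmul(x, w):
+            return torch.matmul(x, w)
+
+        x = torch.randn(15, 20)
+        w = torch.randn(20, 30)
+        rt_mod = optimize_torch(matmul, (x, w))
+        out = rt_mod(x, w)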
+ """ + + if target is None: + target = llvm_target() + + if tuning_config is None: + warning_msg = ( + "Using the default tuning parameters.", + "The default number of trials is set to a small value to let tuning finish quickly.", + "For optimal performance, it is recommended to provide", + "the `tuning_config` argument with a bigger number of trials.", + ) + warnings.warn(" ".join(warning_msg), stacklevel=2) + + # If `func` is already a traced module this statement makes no effect + jit_mod = torch.jit.trace(func, example_inputs) + + if isinstance(example_inputs, torch.Tensor): + example_inputs = [example_inputs] + + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule + if work_dir: + context_manager = contextlib.nullcontext(work_dir) + else: + context_manager = tempfile.TemporaryDirectory() + with context_manager as work_dir_path: + executor_factory = tune_relay_auto( + mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path + ) + + save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") + save_runtime_mod(executor_factory.module) + + return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper()) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index e4bdd1206506..9edf08a5ba72 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -435,11 +435,23 @@ def my_function(x: T.handle): # 1. Argument types T.evaluate(0) # 4. This function returns 0 """ + def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]): + if isinstance(decorator, ast.Call): + if len(decorator.params) != 1: + return False + func_name = decorator.func_name + else: + func_name = decorator + if isinstance(func_name, ast.Var): + return func_name.id.name == "as_torch" + def check_decorator(decorators: List[ast.Expr]) -> bool: """Check the decorator is `T.prim_func""" - if len(decorators) != 1: + if len(decorators) > 2 or len(decorators) == 0: + return False + if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]): return False - d: ast.Expr = decorators[0] + d: ast.Expr = decorators[-1] return ( isinstance(d, ast.Attr) and isinstance(d.object, ast.Var) diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h new file mode 100644 index 000000000000..859fd1abcfd0 --- /dev/null +++ b/src/contrib/torch/base64.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file base64.h
+ * \brief Util functions for decoding base64 strings back to plain bytes
+ */
+
+#ifndef TVM_CONTRIB_TORCH_BASE64_H_
+#define TVM_CONTRIB_TORCH_BASE64_H_
+
+#include <tvm/runtime/logging.h>
+
+#include <cctype>
+#include <cstdio>
+#include <string>
+
+#include "../../support/base64.h"
+
+namespace tvm {
+namespace support {
+
+size_t b64strlen(const std::string b64str) {
+  ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
+  size_t length = b64str.size() / 4 * 3;
+  if (b64str[b64str.size() - 2] == '=') {
+    length -= 2;
+  } else if (b64str[b64str.size() - 1] == '=') {
+    length -= 1;
+  }
+  return length;
+}
+
+void b64decode(const std::string b64str, u_char* ret) {
+  size_t index = 0;
+  const auto length = b64str.size();
+  for (size_t i = 0; i < length; i += 4) {
+    int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
+    int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
+    int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
+    int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
+    u_char st1 = (ch0 << 2) + (ch1 >> 4);
+    ret[index++] = st1;
+    if (b64str[i + 2] != '=') {
+      u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
+      ret[index++] = st2;
+      if (b64str[i + 3] != '=') {
+        u_char st3 = ((ch2 & 0b11) << 6) + ch3;
+        ret[index++] = st3;
+      }
+    }
+  }
+  ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
+}
+
+}  // namespace support
+}  // namespace tvm
+
+#endif  // TVM_CONTRIB_TORCH_BASE64_H_
diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
new file mode 100644
index 000000000000..12c1017bea76
--- /dev/null
+++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_base64` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "Failed to remove temporary file (" << file_name << ")"; + } + std::string file_name; +}; + +tvm::runtime::Module deserialize(SerializationType state) { + auto length = tvm::support::b64strlen(state); + + std::vector bytes(length); + tvm::support::b64decode(state, bytes.data()); + + const std::string name = tmpnam(NULL); + auto file_name = name + ".so"; + std::unique_ptr pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); + fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); + fflush(pFile.get()); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + tvm::runtime::Module ret = (*f)(file_name, ""); + + return ret; +} + +/** + * @brief A Torch's module which wraps TVM's OperatorModule Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. 
+ */ +class OperatorModuleWrapper : public torch::jit::CustomClassHolder { + public: + OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } + + void forward(const c10::List& inputs) { + int input_length = inputs.size(); + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__"); + + std::vector tvm_values(input_length); + std::vector tvm_type_codes(input_length); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + for (int k = 0; k < input_length; ++k) { + setter(k, &tensors[k]->dl_tensor); + } + + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length), + nullptr); + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + } + + SerializationType Serialize() { return serialize(runtime_module); } + + explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } + + private: + tvm::runtime::Module runtime_module; +}; + +tvm::Device getDevice(const at::Tensor& tensor) { + tvm::Device dev; + dev.device_id = tensor.get_device(); + switch (tensor.device().type()) { + case at::DeviceType::CPU: + dev.device_type = DLDeviceType::kDLCPU; + if (dev.device_id == -1) { + /* + * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning + * Thus we manually set the device ID as 0 for avoiding potentially error of index out of + * bounds + */ + dev.device_id = 0; + } + break; + case at::DeviceType::CUDA: + dev.device_type = DLDeviceType::kDLCUDA; + break; + default: + TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str()); + } + return dev; +} + +/** + * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. 
+ */ +class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { + public: + explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) + : executor_factory_(executor_factory) { + CHECK(executor_factory_->IsInstance()) + << "module is not an instance of GraphExecutorFactory"; + } + + GraphExecutorFactoryWrapper() + : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {} + + c10::List forward(const c10::List& inputs) { + int input_length = inputs.size(); + + if (!executor_.defined()) { + TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); + DLDevice input_device = getDevice(inputs.get(0)); + + auto tmp = executor_factory_.GetFunction("default"); + + executor_ = tmp(input_device); + } + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + tvm::runtime::PackedFunc run = executor_.GetFunction("run"); + tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input"); + tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output"); + tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs"); + + for (int k = 0; k < input_length; ++k) { + set_input(k, &tensors[k]->dl_tensor); + } + + run(); + + int64_t output_length = get_num_outputs(); + + c10::List outputs; + outputs.reserve(output_length); + + for (int k = 0; k < output_length; ++k) { + tvm::runtime::NDArray results = get_output(k); + at::Tensor atTensor = at::fromDLPack(results.ToDLPack()); + outputs.emplace_back(atTensor); + } + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + return outputs; + } + + SerializationType Serialize() { return serialize(executor_factory_); } + + explicit GraphExecutorFactoryWrapper(SerializationType state) { + executor_factory_ = deserialize(state); + } + + private: + tvm::runtime::Module executor_factory_; + tvm::runtime::Module executor_; +}; + +TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { + ThreadLocalStore::ThreadLocal()->mod = mod; +}); + +TORCH_LIBRARY(tvm_torch, m) { + m.class_("OperatorModuleWrapper") + .def(torch::init<>()) + .def("forward", &OperatorModuleWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializationType { + return self->Serialize(); + }, + [](SerializationType state) { + return c10::make_intrusive(state); + }); + m.class_("GraphExecutorFactoryWrapper") + .def(torch::init<>()) + .def("forward", &GraphExecutorFactoryWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializationType { + return self->Serialize(); + }, + [](SerializationType state) { + return c10::make_intrusive(state); + }); +} + +} // namespace contrib +} // namespace tvm