[PyTorch] Add PyTorchTVM: compile TorchScript to TVM and export as pytorch_op #8777
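In outline, this PR adds a tvm.contrib.pt_op frontend that compiles a TorchScript module with TVM and wraps the result as a custom PyTorch op, so the compiled model can be scripted and saved like an ordinary torch.nn.Module. A minimal sketch of the intended workflow, distilled from the test scripts in this diff (the 224x224 input and float32 dtype are substitutions for the values flagged in review; the option keys mirror the test script below):

import torch
from torchvision.models import resnet50
from tvm.contrib.pt_op import compile

# Trace the model, then hand the TorchScript graph to the TVM frontend.
model = resnet50().cuda().eval()
x = torch.rand([1, 3, 224, 224]).cuda()
model_jit = torch.jit.trace(model, x)

option = {
    "input_infos": [("x", (1, 3, 224, 224))],
    "default_dtype": "float32",
    "export_dir": "pytorch_compiled",
    "num_outputs": 1,
    "tuning_n_trials": 0,  # set zero to skip tuning
    "tuning_log_file": "tuning.log",
}

# compile() returns a module whose forward takes a list of tensors.
pytorch_tvm_module = compile(model_jit, option)
torch.jit.script(pytorch_tvm_module).save("model_tvm.pt")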
Changes from 3 commits
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tf_tvmdsoop C CXX)

set(TFTVM_COMPILE_FLAGS -std=c++14)
set(BUILD_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)

include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dlpack/include/)
include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dmlc-core/include/)
include_directories(${TVM_ROOT}/include)

link_directories(${TVM_ROOT}/build)

include(${TVM_ROOT}/cmake/utils/FindCUDA.cmake)
include(${TVM_ROOT}/cmake/modules/CUDA.cmake)

include(${TVM_ROOT}/cmake/modules/contrib/PT_TVMCLASS.cmake)
@@ -0,0 +1,35 @@
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
echo "TVM_ROOT=${TVM_ROOT}"

export PYTHONPATH=${TVM_ROOT}/python

python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1
if [ "$?" -eq 0 ]; then
  echo "Build PT_TVMCLASS with gpu support and execute tests"
  CMAKE_OPTIONS="-DUSE_CUDA=/data00/liuxin.ai/cuda_111 -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"
Review comment: update the hardcoded CUDA path /data00/liuxin.ai/cuda_111.
  mkdir -p build
  cd build; cmake .. ${CMAKE_OPTIONS} && make
  cd ..

  LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
fi
@@ -0,0 +1,60 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tf op module"""
Review comment: "tf" should be "pt".
import torch
import time
from torchvision.models import resnet50
from tvm.contrib.pt_op import compile


model = resnet50().half().cuda()
x = torch.rand([1, 3, 244, 244]).half().cuda()
Review comment: should this be 224?
model_jit = torch.jit.trace(model, x)
print(model_jit.graph)

print("run torchscript...")
for i in range(20):
    t = time.time()
    model_jit(x)
    torch.cuda.synchronize()
    print(time.time() - t)


option = {
    "input_infos": [
        ("x", (1, 3, 244, 244)),
Review comment: should this be 224?
    ],
    "default_dtype": "float16",
    "export_dir": "pytorch_compiled",
    "num_outputs": 1,
    "tuning_n_trials": 20,  # set zero to skip tuning
    "tuning_log_file": "tuning.log",
}

pytorch_tvm_module = compile(model_jit, option)
torch.jit.script(pytorch_tvm_module).save("model_tvm.pt")


print("Run PyTorch...")
for i in range(20):
    t = time.time()
    outputs = pytorch_tvm_module.forward([x])
    torch.cuda.synchronize()
    print(1000 * (time.time() - t))
print(outputs[0].shape)
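Because the result is saved as plain TorchScript, it can presumably be reloaded in a process that has only PyTorch plus the TVM runtime libraries, without rerunning the compile frontend. A hedged sketch (shapes and dtype follow the script above):

import torch

# Requires the pt_tvmdsoop op library to be loadable at import time.
loaded = torch.jit.load("model_tvm.pt")
x = torch.rand([1, 3, 244, 244]).half().cuda()
outputs = loaded.forward([x])  # the wrapped op takes a list of tensors
print(outputs[0].shape)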
@@ -0,0 +1,123 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tf op module"""
Review comment: "tf" should be "pt".
import tempfile
import os
import logging
import torch
import numpy as np
import tvm
import tvm.testing
from tvm import te, relay
from tvm.contrib import pt_op
from tvm.contrib import graph_runtime


def test_use_pt_graph_module():
    """main test function"""

    def build_export_graph(device):
        """relay build & export graph"""
        x = relay.var("x", shape=(10, 5))
        y = relay.var("y", shape=(1, 5))
        z = relay.add(x, y)
        z = relay.exp(z)
        func = relay.Function([x, y], z)
        x_data = np.random.rand(10, 5).astype("float32")
        y_data = np.random.rand(1, 5).astype("float32")
        params = {"y": y_data}

        pt_device = torch.device(device)
        if pt_device.type == 'cuda':
            target = 'cuda'
            ctx = tvm.gpu(pt_device.index)
        else:
            target = 'llvm'
            ctx = tvm.cpu(0)

        graph, lib, params = relay.build(tvm.IRModule.from_expr(func), target=target, params=params)
        mod = graph_runtime.create(graph, lib, device=ctx)
        mod.set_input(**params)
        mod.set_input(x=x_data)
        mod.run()
        res = mod.get_output(0).asnumpy()
        ref_res = np.exp(y_data + x_data)
        tvm.testing.assert_allclose(res, ref_res, atol=1e-5, rtol=1e-5)

        # export to tempdir
        tvm_assets = ["mod.so", "graph.json", "params"]
        export_dir = tempfile.mkdtemp("tvm_export")
        lib.export_library(os.path.join(export_dir, tvm_assets[0]))
        with open(os.path.join(export_dir, tvm_assets[1]), 'w') as fout:
            fout.write(graph)
        with open(os.path.join(export_dir, tvm_assets[2]), 'wb') as fout:
            fout.write(relay.save_param_dict(params))

        return export_dir

    def test_pt_run(device, trace=True, to_device=None):
        """test add lib with Pytorch wrapper"""
        print('############## Test on device:', device, '#################')
        export_dir = build_export_graph(device)
        engine = pt_op.module.GraphModule(num_inputs=2, num_outputs=1).to(device)

        x = np.random.rand(10, 5).astype("float32")
        y = np.random.rand(1, 5).astype("float32")

        expect = np.exp(y + x)

        tvm_assets = ["mod.so", "graph.json", "params"]
        assets = [os.path.join(export_dir, i) for i in tvm_assets]
        engine.init((x.shape, y.shape), *assets)

        def get_inputs_by_device(device):
            if device == 'cpu':
                inputs = [torch.Tensor(x), torch.Tensor(y)]
            else:
                inputs = [torch.Tensor(x).cuda(), torch.Tensor(y).cuda()]
            return inputs

        outputs = engine.forward(get_inputs_by_device(device))
        tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)

        if trace:
            print('################ Test trace and load #################')
            scripted = torch.jit.script(engine)
            scripted_dir = tempfile.mkdtemp("scripted")
            scripted_path = os.path.join(scripted_dir, 'model.pt')
            scripted.save(scripted_path)
            loaded = torch.jit.load(scripted_path)
            outputs = loaded.forward(get_inputs_by_device(device))
            tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)
            del scripted
            del loaded

        if to_device:
            print('################ Test move from [{}] to [{}] #################'.format(device, to_device))
            engine = engine.to(to_device)
            outputs = engine.forward(get_inputs_by_device(to_device))
            tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5)
        del engine

    test_pt_run(device='cuda:0', trace=True, to_device='cuda:1')
    test_pt_run(device='cpu', trace=True)


if __name__ == "__main__":
    test_use_pt_graph_module()
@@ -0,0 +1,115 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tf op module"""
import os
import torch
import time
import numpy as np
import tvm
import tvm.testing
from tvm.contrib.pt_op import PyTorchTVMModule, compile


class Model(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y.softmax(1))


model = Model()
model.cuda().half()
x = torch.rand([1280, 2464, 4]).cuda().half()
y = torch.rand([1280, 4, 1]).cuda().half()
for i in range(20):
    t = time.time()
    o = model(x, y)
    torch.cuda.synchronize()
    print(1000 * (time.time() - t))
print(o.shape)


model_jit = torch.jit.script(model)
print(model_jit.graph)
input_shapes = [("x", list(x.shape)), ("y", list(y.shape))]
dtype = "float16"
export_dir = "pytorch_compiled"


mod = PyTorchTVMModule()
print("Converting...")
mod.from_pytorch(model_jit, input_shapes, dtype)

log_file = "tuning.log"
if not os.path.exists(log_file):
    print("Tuning...")
    mod.tune_tvm(log_file=log_file, n_trial=20)

print("Building...")
tvm_mod = mod.build_tvm(export_dir)
pytorch_mod = mod.build_pytorch_op(num_inputs=2, num_outputs=1)


## Or you can load from a prebuilt tvm module
# mod = PyTorchTVMModule()
# tvm_mod = mod.load_tvm(export_dir)
# pytorch_mod = mod.build_pytorch_op(num_inputs=2, num_outputs=1, input_infos=input_shapes)


print("Run TVM...")
tvm_x = tvm.nd.array(x.cpu().numpy().astype(dtype), device=tvm.gpu(0))
tvm_y = tvm.nd.array(y.cpu().numpy().astype(dtype), device=tvm.gpu(0))
for i in range(20):
    t = time.time()
    tvm_mod.run(x=tvm_x, y=tvm_y)
    print(1000 * (time.time() - t))
tvm_output = tvm_mod.get_output(0)
print(tvm_output.shape)


print("Run PyTorch...")
for i in range(20):
    t = time.time()
    outputs = pytorch_mod.forward([x, y])
    torch.cuda.synchronize()
    print(1000 * (time.time() - t))
print(outputs[0].shape)


class EnsembleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.jit.script(pytorch_mod)

    def forward(self, x, y, z) -> torch.Tensor:
        if x > 1:
            out = self.layer(y, z)[0]
        else:
            out = torch.ones([1280, 2464, 1])
        return out


print("Exporting...")
scripted = torch.jit.script(EnsembleModel())
print(scripted.graph)
scripted_path = os.path.join(export_dir, 'model_tvm.pt')
scripted.save(scripted_path)


# print(o == outputs[0])
# print(o - outputs[0])
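The EnsembleModel above shows the exported op composing with ordinary TorchScript control flow. A hedged usage sketch for the saved ensemble (the scalar x only selects the branch; shapes follow the script above):

import torch

loaded = torch.jit.load("pytorch_compiled/model_tvm.pt")
x = torch.tensor(2.0)  # > 1, so the TVM-compiled branch runs
y = torch.rand([1280, 2464, 4]).cuda().half()
z = torch.rand([1280, 4, 1]).cuda().half()
out = loaded(x, y, z)
print(out.shape)  # expected: [1280, 2464, 1]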
Review comment: update the remaining 'TF' / tf references (e.g. the "Test script for tf op module" docstrings and the tf_tvmdsoop project name).