-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
535 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""Test Script""" | ||
import logging | ||
import tempfile | ||
import sys | ||
from os import path as osp | ||
import onnx | ||
import yaml | ||
import numpy as np | ||
import tvm | ||
from tvm import relay | ||
from tvm.contrib import graph_executor | ||
from tvm.meta_schedule import TuneConfig | ||
from tvm.meta_schedule.database import JSONDatabase | ||
from tvm.meta_schedule.tune import tune_relay | ||
from tvm.target.target import Target | ||
from tvm.runtime.vm import VirtualMachine | ||
|
||
|
||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
|
||
if __name__ == "__main__": | ||
model_name = str(sys.argv[1]) | ||
if str(sys.argv[2]) == "cpu": | ||
target = Target("llvm --num-cores=8") | ||
else: | ||
target = Target("nvidia/geforce-rtx-3070") | ||
dev = tvm.cpu() if str(target.kind) == "llvm" else tvm.cuda() | ||
docs = yaml.safe_load(open("/home/zxybazh/tvm-tensorir/models/models.yaml", "r")) | ||
onnx_model = onnx.load(f"/home/zxybazh/tvm-tensorir/models/{model_name}.onnx") | ||
if model_name not in [w["name"] for w in docs]: | ||
raise Exception("Model not found!") | ||
else: | ||
for entry in docs: | ||
if entry["name"] == model_name: | ||
doc = entry | ||
break | ||
shape_dict = {} | ||
datas = {} | ||
for input in doc["input_shapes"]: | ||
shape_dict[input["name"]] = input["shape"] | ||
if "int" not in input["dtype"]: | ||
data = tvm.nd.array(np.random.randn(*input["shape"]).astype(input["dtype"]), dev) | ||
elif model_name == "bert": | ||
data = tvm.nd.array( | ||
np.random.randint(0, 30521, size=input["shape"]).astype(input["dtype"]), dev | ||
) | ||
else: | ||
assert model_name == "gpt2" # check embedding size 50257 here | ||
data = tvm.nd.array( | ||
np.random.randint(0, 50256, size=input["shape"]).astype(input["dtype"]), dev | ||
) | ||
datas[input["name"]] = data | ||
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) | ||
|
||
def get_output(lib): | ||
module = graph_executor.GraphModule(lib["default"](dev)) | ||
module.set_input(**datas) | ||
module.run() | ||
return module.get_output(0).numpy() | ||
|
||
print("Starting to build with relay.", flush=True) | ||
|
||
# Compile without meta-scheduler for correctness check | ||
|
||
def vmobj_to_list(o, dtype="float32"): | ||
if isinstance(o, tvm.nd.NDArray): | ||
return [o] | ||
elif isinstance(o, tvm.runtime.container.ADT): | ||
result = [] | ||
for f in o: | ||
result.extend(vmobj_to_list(f, dtype)) | ||
return result | ||
else: | ||
raise RuntimeError("Unknown object type: %s" % type(o)) | ||
|
||
# print("Graph Executor failed, using virtual machine.", flush=True) | ||
with tvm.transform.PassContext(opt_level=3): | ||
vm_exec = relay.vm.compile(mod, target=target, params=params) | ||
|
||
vm = VirtualMachine(vm_exec, dev) | ||
vm.set_input("main", **datas) | ||
o = vm.run() | ||
if isinstance(o, tvm.nd.NDArray): | ||
expected_output = o.numpy() | ||
elif isinstance(o, tvm.runtime.container.ADT): | ||
result = [] | ||
for f in o: | ||
result.extend(vmobj_to_list(f)) | ||
expected_output = result[0] | ||
else: | ||
raise TypeError("Unknown object type: %s" % type(o)) | ||
|
||
logger.info("Starting to tune with meta schedule.") | ||
work_dir = "/home/zxybazh/ms-db/" | ||
rt_mod1: tvm.runtime.Module = tune_relay( | ||
mod=mod, | ||
params=params, | ||
target=target, | ||
config=TuneConfig( | ||
strategy="replay_trace", | ||
num_trials_per_iter=32, | ||
max_trials_per_task=32, | ||
max_trials_global=20000, | ||
), | ||
work_dir=work_dir, | ||
database=JSONDatabase( | ||
osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json") | ||
), | ||
) | ||
logger.info("Finished tuning with meta schedule.") | ||
|
||
# Check correctness | ||
actual_output = get_output(rt_mod1) | ||
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Test Script""" | ||
import logging | ||
import tempfile | ||
import sys | ||
from os import path as osp | ||
import onnx | ||
import yaml | ||
import numpy as np | ||
import tvm | ||
from tvm import relay | ||
from tvm.contrib import graph_executor | ||
from tvm.meta_schedule.database import JSONDatabase | ||
from tvm.meta_schedule.tune import ReplayTraceConfig, tune_relay | ||
from tvm.target.target import Target | ||
from tvm.runtime.vm import VirtualMachine | ||
|
||
|
||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
|
||
if __name__ == "__main__": | ||
model_name = str(sys.argv[1]) | ||
if str(sys.argv[2]) == "cpu": | ||
target = Target("llvm --num-cores=8") | ||
else: | ||
target = Target("nvidia/geforce-rtx-3070") | ||
dev = tvm.cpu() if str(target.kind) == "llvm" else tvm.cuda() | ||
docs = yaml.safe_load(open("/home/zxybazh/tvm-tensorir/models/models.yaml", "r")) | ||
onnx_model = onnx.load(f"/home/zxybazh/tvm-tensorir/models/{model_name}.onnx") | ||
if model_name not in [w["name"] for w in docs]: | ||
raise Exception("Model not found!") | ||
else: | ||
for entry in docs: | ||
if entry["name"] == model_name: | ||
doc = entry | ||
break | ||
shape_dict = {} | ||
datas = {} | ||
for input in doc["input_shapes"]: | ||
shape_dict[input["name"]] = input["shape"] | ||
if "int" not in input["dtype"]: | ||
data = tvm.nd.array(np.random.randn(*input["shape"]).astype(input["dtype"]), dev) | ||
elif model_name == "bert": | ||
data = tvm.nd.array( | ||
np.random.randint(0, 30521, size=input["shape"]).astype(input["dtype"]), dev | ||
) | ||
else: | ||
assert model_name == "gpt2" # check embedding size 50257 here | ||
data = tvm.nd.array( | ||
np.random.randint(0, 50256, size=input["shape"]).astype(input["dtype"]), dev | ||
) | ||
datas[input["name"]] = data | ||
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) | ||
|
||
def get_output(lib): | ||
module = graph_executor.GraphModule(lib["default"](dev)) | ||
module.set_input(**datas) | ||
module.run() | ||
return module.get_output(0).numpy() | ||
|
||
print("Starting to build with relay.", flush=True) | ||
|
||
# Compile without meta-scheduler for correctness check | ||
|
||
def vmobj_to_list(o, dtype="float32"): | ||
if isinstance(o, tvm.nd.NDArray): | ||
return [o] | ||
elif isinstance(o, tvm.runtime.container.ADT): | ||
result = [] | ||
for f in o: | ||
result.extend(vmobj_to_list(f, dtype)) | ||
return result | ||
else: | ||
raise RuntimeError("Unknown object type: %s" % type(o)) | ||
|
||
# print("Graph Executor failed, using virtual machine.", flush=True) | ||
with tvm.transform.PassContext(opt_level=3): | ||
vm_exec = relay.vm.compile(mod, target=target, params=params) | ||
|
||
vm = VirtualMachine(vm_exec, dev) | ||
vm.set_input("main", **datas) | ||
o = vm.run() | ||
if isinstance(o, tvm.nd.NDArray): | ||
expected_output = o.numpy() | ||
else: | ||
raise TypeError("Unknown object type: %s" % type(o)) | ||
|
||
logger.info("Starting to tune with meta schedule.") | ||
work_dir = "/home/zxybazh/ms-db/" | ||
rt_mod1: tvm.runtime.Module = tune_relay( | ||
mod=mod, | ||
params=params, | ||
target=target, | ||
config=ReplayTraceConfig( | ||
num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=20000 | ||
), | ||
work_dir=work_dir, | ||
database=JSONDatabase( | ||
osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json") | ||
), | ||
) | ||
logger.info("Finished tuning with meta schedule.") | ||
|
||
# Check correctness | ||
actual_output = get_output(rt_mod1) | ||
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) |
Oops, something went wrong.