From 7157c93606b65dfc0110edbd733d80bfd537f584 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Fri, 18 Jun 2021 17:01:55 -0700 Subject: [PATCH] add metal to list of choices (#8282) --- python/tvm/driver/tvmc/runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 191f8616c405..d4d02e9f96fe 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -19,7 +19,7 @@ """ import json import logging -from typing import Optional, Dict, List, Union +from typing import Dict, List, Optional, Union import numpy as np import tvm @@ -30,12 +30,11 @@ from tvm.relay.param_dict import load_param_dict from . import common -from .model import TVMCPackage, TVMCResult from .common import TVMCException from .main import register_parser +from .model import TVMCPackage, TVMCResult from .result_utils import get_top_results - # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -51,7 +50,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl"], + choices=["cpu", "cuda", "cl", "metal"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -391,6 +390,8 @@ def run_module( dev = session.cuda() elif device == "cl": dev = session.cl() + elif device == "metal": + dev = session.metal() else: assert device == "cpu" dev = session.cpu()