diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index d3559cc5f7fb..7b9538f6ce03 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -393,6 +393,7 @@ def _create_prj_conf(self, project_dir, options): if options["project_type"] == "host_driven": f.write( + "CONFIG_TIMING_FUNCTIONS=y\n" "# For RPC server C++ bindings.\n" "CONFIG_CPLUSPLUS=y\n" "CONFIG_LIB_CPLUSPLUS=y\n" diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c index 623266c0cae0..ff02b3cb1d44 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -144,11 +145,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { return kTvmErrorNoError; } -#define MILLIS_TIL_EXPIRY 200 -#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) -K_TIMER_DEFINE(g_microtvm_timer, /* expiry func */ NULL, /* stop func */ NULL); - -uint32_t g_microtvm_start_time; +volatile timing_t g_microtvm_start_time, g_microtvm_end_time; int g_microtvm_timer_running = 0; // Called to start system timer. @@ -161,8 +158,7 @@ tvm_crt_error_t TVMPlatformTimerStart() { #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 1); #endif - k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); - g_microtvm_start_time = k_cycle_get_32(); + g_microtvm_start_time = timing_counter_get(); g_microtvm_timer_running = 1; return kTvmErrorNoError; } @@ -174,43 +170,14 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorSystemErrorMask | 2; } - uint32_t stop_time = k_cycle_get_32(); #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 0); #endif - // compute how long the work took - uint32_t cycles_spent = stop_time - g_microtvm_start_time; - if (stop_time < g_microtvm_start_time) { - // we rolled over *at least* once, so correct the rollover it was *only* - // once, because we might still use this result - cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time); - } - - uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent); - double hw_clock_res_us = ns_spent / 1000.0; - - // need to grab time remaining *before* stopping. when stopped, this function - // always returns 0. - int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer); - k_timer_stop(&g_microtvm_timer); - // check *after* stopping to prevent extra expiries on the happy path - if (time_remaining_ms < 0) { - TVMLogf("negative time remaining"); - return kTvmErrorSystemErrorMask | 3; - } - uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer); - uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); - double approx_num_cycles = - (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); - // if we approach the limits of the HW clock datatype (uint32_t), use the - // coarse-grained timer result instead - if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) { - *elapsed_time_seconds = timer_res_ms / 1000.0; - } else { - *elapsed_time_seconds = hw_clock_res_us / 1e6; - } - + g_microtvm_end_time = timing_counter_get(); + uint64_t cycles = timing_cycles_get(&g_microtvm_start_time, &g_microtvm_end_time); + uint64_t ns_spent = timing_cycles_to_ns(cycles); + *elapsed_time_seconds = ns_spent / (double)1e9; g_microtvm_timer_running = 0; return kTvmErrorNoError; } @@ -278,6 +245,11 @@ void main(void) { tvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console))); uart_rx_init(&uart_rx_rbuf, tvm_uart); + // Initialize system timing. We could stop and start it every time, but we'll + // be using it enough we should just keep it enabled. + timing_init(); + timing_start(); + // Initialize microTVM RPC server, which will receive commands from the UART and execute them. microtvm_rpc_server_t server = MicroTVMRpcServerInit(write_serial, NULL); TVMLogf("microTVM Zephyr runtime - running"); diff --git a/python/tvm/micro/testing/__init__.py b/python/tvm/micro/testing/__init__.py new file mode 100644 index 000000000000..9062f061bda3 --- /dev/null +++ b/python/tvm/micro/testing/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Allows the tools specified below to be imported directly from tvm.micro.testing""" +from .evaluation import tune_model, create_aot_session, evaluate_model_accuracy +from .utils import get_supported_boards, get_target diff --git a/python/tvm/micro/testing/aot_test_utils.py b/python/tvm/micro/testing/aot_test_utils.py index 82ac1ac68e9d..89c08395deb7 100644 --- a/python/tvm/micro/testing/aot_test_utils.py +++ b/python/tvm/micro/testing/aot_test_utils.py @@ -15,17 +15,22 @@ # specific language governing permissions and limitations # under the License. +""" +This file provides utilities for running AOT tests, especially for Corstone. + +""" + import logging import itertools import shutil import pytest -pytest.importorskip("tvm.micro") - import tvm from tvm.testing.aot import AOTTestRunner +pytest.importorskip("tvm.micro") + _LOG = logging.getLogger(__name__) @@ -97,9 +102,9 @@ def parametrize_aot_options(test): valid_combinations, ) - fn = pytest.mark.parametrize( + func = pytest.mark.parametrize( ["interface_api", "use_unpacked_api", "test_runner"], marked_combinations, )(test) - return tvm.testing.skip_if_32bit(reason="Reference system unavailable in i386 container")(fn) + return tvm.testing.skip_if_32bit(reason="Reference system unavailable in i386 container")(func) diff --git a/python/tvm/micro/testing/evaluation.py b/python/tvm/micro/testing/evaluation.py new file mode 100644 index 000000000000..c60f0fc4828e --- /dev/null +++ b/python/tvm/micro/testing/evaluation.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Provides high-level functions for instantiating and timing AOT models. Used +by autotuning tests in tests/micro, and may be used for more performance +tests in the future. + +""" + +from io import StringIO +from pathlib import Path +from contextlib import ExitStack +import tempfile + +import tvm + + +def tune_model( + platform, board, target, mod, params, num_trials, tuner_cls=tvm.autotvm.tuner.GATuner +): + """Autotunes a model with microTVM and returns a StringIO with the tuning logs""" + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target) + assert len(tasks) > 0 + assert isinstance(params, dict) + + module_loader = tvm.micro.AutoTvmModuleLoader( + template_project_dir=tvm.micro.get_microtvm_template_projects(platform), + project_options={ + f"{platform}_board": board, + "project_type": "host_driven", + }, + ) + + builder = tvm.autotvm.LocalBuilder( + n_parallel=1, + build_kwargs={"build_option": {"tir.disable_vectorize": True}}, + do_fork=False, + build_func=tvm.micro.autotvm_build_func, + runtime=tvm.relay.backend.Runtime("crt", {"system-lib": True}), + ) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) + measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) + + results = StringIO() + for task in tasks: + tuner = tuner_cls(task) + + tuner.tune( + n_trial=num_trials, + measure_option=measure_option, + callbacks=[ + tvm.autotvm.callback.log_to_file(results), + tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"), + ], + si_prefix="M", + ) + assert tuner.best_flops > 1 + + return results + + +def create_aot_session( + platform, + board, + target, + mod, + params, + build_dir=Path(tempfile.mkdtemp()), + tune_logs=None, + use_cmsis_nn=False, +): + """AOT-compiles and uploads a model to a microcontroller, and returns the RPC session""" + + executor = tvm.relay.backend.Executor("aot") + crt_runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True}) + + with ExitStack() as stack: + config = {"tir.disable_vectorize": True} + if use_cmsis_nn: + config["relay.ext.cmsisnn.options"] = {"mcpu": target.mcpu} + stack.enter_context(tvm.transform.PassContext(opt_level=3, config=config)) + if tune_logs is not None: + stack.enter_context(tvm.autotvm.apply_history_best(tune_logs)) + + lowered = tvm.relay.build( + mod, + target=target, + params=params, + runtime=crt_runtime, + executor=executor, + ) + parameter_size = len(tvm.runtime.save_param_dict(lowered.get_params())) + print(f"Model parameter size: {parameter_size}") + + # Once the project has been uploaded, we don't need to keep it + project = tvm.micro.generate_project( + str(tvm.micro.get_microtvm_template_projects(platform)), + lowered, + build_dir / "project", + { + f"{platform}_board": board, + "project_type": "host_driven", + }, + ) + project.build() + project.flash() + + return tvm.micro.Session(project.transport()) + + +# This utility functions was designed ONLY for one input / one output models +# where the outputs are confidences for different classes. +def evaluate_model_accuracy(session, aot_executor, input_data, true_labels, runs_per_sample=1): + """Evaluates an AOT-compiled model's accuracy and runtime over an RPC session. Works well + when used with create_aot_session.""" + + assert aot_executor.get_num_inputs() == 1 + assert aot_executor.get_num_outputs() == 1 + assert runs_per_sample > 0 + + predicted_labels = [] + aot_runtimes = [] + for sample in input_data: + aot_executor.get_input(0).copyfrom(sample) + result = aot_executor.module.time_evaluator("run", session.device, number=runs_per_sample)() + runtime = result.mean + output = aot_executor.get_output(0).numpy() + predicted_labels.append(output.argmax()) + aot_runtimes.append(runtime) + + num_correct = sum(u == v for u, v in zip(true_labels, predicted_labels)) + average_time = sum(aot_runtimes) / len(aot_runtimes) + accuracy = num_correct / len(predicted_labels) + return average_time, accuracy diff --git a/python/tvm/micro/testing/utils.py b/python/tvm/micro/testing/utils.py index a48c8dc3230f..820b649c74ee 100644 --- a/python/tvm/micro/testing/utils.py +++ b/python/tvm/micro/testing/utils.py @@ -17,9 +17,10 @@ """Defines the test methods used with microTVM.""" -import pathlib +from functools import lru_cache import json import logging +from pathlib import Path import tarfile import time from typing import Union @@ -32,7 +33,19 @@ TIMEOUT_SEC = 10 -def check_tune_log(log_path: Union[pathlib.Path, str]): +@lru_cache(maxsize=None) +def get_supported_boards(platform: str): + template = Path(tvm.micro.get_microtvm_template_projects(platform)) + with open(template / "boards.json") as f: + return json.load(f) + + +def get_target(platform: str, board: str): + model = get_supported_boards(platform)[board]["model"] + return str(tvm.target.target.micro(model)) + + +def check_tune_log(log_path: Union[Path, str]): """Read the tuning log and check each result.""" with open(log_path, "r") as f: lines = f.readlines() @@ -76,7 +89,7 @@ def _read_line(transport, timeout_sec: int) -> str: return data.decode(encoding="utf-8") -def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[pathlib.Path, str]) -> int: +def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[Path, str]) -> int: """Extract an MLF archive file and read workspace size from metadata file.""" workspace_size = 0 diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index d7c2adaa8606..47bdab5828b9 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -67,6 +67,7 @@ def test_something(): import copyreg import ctypes import functools +import hashlib import itertools import logging import os @@ -77,7 +78,7 @@ def test_something(): import time from pathlib import Path -from typing import Optional, Callable, Union, List +from typing import Optional, Callable, Union, List, Tuple import pytest import numpy as np @@ -90,6 +91,7 @@ def test_something(): from tvm.contrib import nvcc, cudnn import tvm.contrib.hexagon._ci_env_check as hexagon +from tvm.driver.tvmc.frontends import load_model from tvm.error import TVMError @@ -1661,6 +1663,47 @@ def install_request_hook(depth: int) -> None: request_hook.init() +def fetch_model_from_url( + url: str, + model_format: str, + sha256: str, +) -> Tuple[tvm.ir.module.IRModule, dict]: + """Testing function to fetch a model from a URL and return it as a Relay + model. Downloaded files are cached for future re-use. + + Parameters + ---------- + url : str + The URL or list of URLs to try downloading the model from. + + model_format: str + The file extension of the model format used. + + sha256 : str + The sha256 hex hash to compare the downloaded model against. + + Returns + ------- + (mod, params) : object + The Relay representation of the downloaded model. + """ + + rel_path = f"model_{sha256}.{model_format}" + file = tvm.contrib.download.download_testdata(url, rel_path, overwrite=False) + + # Check SHA-256 hash + file_hash = hashlib.sha256() + with open(file, "rb") as f: + for block in iter(lambda: f.read(2**24), b""): + file_hash.update(block) + + if file_hash.hexdigest() != sha256: + raise FileNotFoundError("SHA-256 hash for model does not match") + + tvmc_model = load_model(file, model_format) + return tvmc_model.mod, tvmc_model.params + + def main(): test_file = inspect.getsourcefile(sys._getframe(1)) sys.exit(pytest.main([test_file] + sys.argv[1:])) diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index d26b047e8121..37b64433b23e 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -140,7 +140,6 @@ "tests/micro/testdata/mnist/digit-2.jpg", "tests/micro/testdata/mnist/digit-9.jpg", "tests/micro/testdata/mnist/mnist-8.onnx", - "tests/micro/testdata/kws/yes_no.tflite", # microTVM Zephyr runtime "apps/microtvm/zephyr/template_project/CMakeLists.txt.template", "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm", diff --git a/tests/micro/arduino/test_utils.py b/tests/micro/arduino/test_utils.py index c107d5b1febf..20e7d9e75001 100644 --- a/tests/micro/arduino/test_utils.py +++ b/tests/micro/arduino/test_utils.py @@ -25,7 +25,7 @@ from tvm.micro import project from tvm import relay from tvm.relay.backend import Executor, Runtime - +from tvm.testing.utils import fetch_model_from_url TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino")) @@ -66,20 +66,12 @@ def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir): model = ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} - with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: - tflite_model_buf = f.read() - - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 - try: - import tflite.Model - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - except AttributeError: - import tflite - - tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + mod, params = fetch_model_from_url( + url="https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite", + model_format="tflite", + sha256="09e5e2a9dfb2d8ed78802bf18ce297bff54281a66ca18e0c23d69ca14f822a83", + ) - mod, params = relay.frontend.from_tflite(tflite_model) target = tvm.target.target.micro(model) runtime = Runtime("crt") executor = Executor("aot", {"unpacked-api": True}) diff --git a/tests/micro/common/conftest.py b/tests/micro/common/conftest.py index 3fbfdbcbc81d..10dda8774bca 100644 --- a/tests/micro/common/conftest.py +++ b/tests/micro/common/conftest.py @@ -21,11 +21,17 @@ def pytest_addoption(parser): + parser.addoption( + "--platform", + required=True, + choices=["arduino", "zephyr"], + help="Platform to run tests with", + ) parser.addoption( "--board", required=True, choices=list(ARDUINO_BOARDS.keys()) + list(ZEPHYR_BOARDS.keys()), - help="microTVM boards for tests.", + help="microTVM boards for tests", ) parser.addoption( "--test-build-only", @@ -34,6 +40,11 @@ def pytest_addoption(parser): ) +@pytest.fixture +def platform(request): + return request.config.getoption("--platform") + + @pytest.fixture def board(request): return request.config.getoption("--board") diff --git a/tests/micro/common/test_autotune.py b/tests/micro/common/test_autotune.py new file mode 100644 index 000000000000..37836563a069 --- /dev/null +++ b/tests/micro/common/test_autotune.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from io import StringIO +import json +from pathlib import Path +import sys +import tempfile +from typing import Union + +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.micro.testing +from tvm.testing.utils import fetch_model_from_url + +TUNING_RUNS_PER_OPERATOR = 2 + + +@pytest.mark.requires_hardware +@tvm.testing.requires_micro +def test_kws_autotune_workflow(platform, board, tmp_path): + mod, params = fetch_model_from_url( + url="https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite", + model_format="tflite", + sha256="09e5e2a9dfb2d8ed78802bf18ce297bff54281a66ca18e0c23d69ca14f822a83", + ) + target = tvm.micro.testing.get_target(platform, board) + + str_io_logs = tvm.micro.testing.tune_model( + platform, board, target, mod, params, TUNING_RUNS_PER_OPERATOR + ) + assert isinstance(str_io_logs, StringIO) + + str_logs = str_io_logs.getvalue().rstrip().split("\n") + logs = list(map(json.loads, str_logs)) + assert len(logs) == 2 * TUNING_RUNS_PER_OPERATOR # Two operators + + # Check we tested both operators + op_names = list(map(lambda x: x["input"][1], logs)) + assert op_names[0] == op_names[1] == "dense_nopack.x86" + assert op_names[2] == op_names[3] == "dense_pack.x86" + + # Make sure we tested different code. != does deep comparison in Python 3 + assert logs[0]["config"]["index"] != logs[1]["config"]["index"] + assert logs[0]["config"]["entity"] != logs[1]["config"]["entity"] + assert logs[2]["config"]["index"] != logs[3]["config"]["index"] + assert logs[2]["config"]["entity"] != logs[3]["config"]["entity"] + + # Compile the best model with AOT and connect to it + with tvm.micro.testing.create_aot_session( + platform, + board, + target, + mod, + params, + build_dir=tmp_path, + tune_logs=str_io_logs, + ) as session: + aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor()) + + samples = ( + np.random.randint(low=-127, high=128, size=(1, 1960), dtype=np.int8) for x in range(3) + ) + + labels = [0, 0, 0] + + # Validate perforance across random runs + time, acc = tvm.micro.testing.evaluate_model_accuracy( + session, aot_executor, samples, labels, runs_per_sample=20 + ) + # `time` is the average time taken to execute model inference on the + # device, measured in seconds. It does not include the time to upload + # the input data via RPC. On slow boards like the Arduino Due, time + # is around 0.12 (120 ms), so this gives us plenty of buffer. + assert time < 1 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py index 24d0213b7754..096e12393d43 100644 --- a/tests/micro/common/test_tvmc.py +++ b/tests/micro/common/test_tvmc.py @@ -29,9 +29,6 @@ import tvm.testing from tvm.contrib.download import download_testdata -from ..zephyr.test_utils import ZEPHYR_BOARDS -from ..arduino.test_utils import ARDUINO_BOARDS - TVMC_COMMAND = [sys.executable, "-m", "tvm.driver.tvmc"] MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite" @@ -47,22 +44,8 @@ def _run_tvmc(cmd_args: list, *args, **kwargs): return subprocess.check_call(cmd_args_list, *args, **kwargs) -def _get_target_and_platform(board: str): - if board in ZEPHYR_BOARDS.keys(): - target_model = ZEPHYR_BOARDS[board] - platform = "zephyr" - elif board in ARDUINO_BOARDS.keys(): - target_model = ARDUINO_BOARDS[board] - platform = "arduino" - else: - raise ValueError(f"Board {board} is not supported.") - - target = tvm.target.target.micro(target_model) - return str(target), platform - - @tvm.testing.requires_micro -def test_tvmc_exist(board): +def test_tvmc_exist(platform, board): cmd_result = _run_tvmc(["micro", "-h"]) assert cmd_result == 0 @@ -72,8 +55,8 @@ def test_tvmc_exist(board): "output_dir,", [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())], ) -def test_tvmc_model_build_only(board, output_dir): - target, platform = _get_target_and_platform(board) +def test_tvmc_model_build_only(platform, board, output_dir): + target = tvm.micro.testing.get_target(platform, board) if not os.path.isabs(output_dir): out_dir_temp = os.path.abspath(output_dir) @@ -138,8 +121,8 @@ def test_tvmc_model_build_only(board, output_dir): "output_dir,", [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())], ) -def test_tvmc_model_run(board, output_dir): - target, platform = _get_target_and_platform(board) +def test_tvmc_model_run(platform, board, output_dir): + target = tvm.micro.testing.get_target(platform, board) if not os.path.isabs(output_dir): out_dir_temp = os.path.abspath(output_dir) diff --git a/tests/micro/testdata/kws/yes_no.tflite b/tests/micro/testdata/kws/yes_no.tflite deleted file mode 100644 index 4f533dac8405..000000000000 Binary files a/tests/micro/testdata/kws/yes_no.tflite and /dev/null differ diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 2274c6ca6b28..e057883776bb 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -38,8 +38,8 @@ run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build- run_pytest ctypes python-microtvm-stm32 tests/micro/stm32 # Common Tests -run_pytest ctypes python-microtvm-common-qemu_x86 tests/micro/common --board=qemu_x86 -run_pytest ctypes python-microtvm-common-due tests/micro/common --test-build-only --board=due +run_pytest ctypes python-microtvm-common-qemu_x86 tests/micro/common --platform=zephyr --board=qemu_x86 +run_pytest ctypes python-microtvm-common-due tests/micro/common --platform=arduino --test-build-only --board=due # Tutorials python3 gallery/how_to/work_with_microtvm/micro_tflite.py