From ee129dd8c9e717421ff8cb6be1a8563a16395e2a Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Fri, 9 Dec 2022 17:14:30 +0400
Subject: [PATCH] [microNPU] Disable copying weights to SRAM for FullyConnected
 ops in CopyConstants scheduler

In the Ethos-U backend, the CopyConstants scheduler currently copies the
weights for all operators. In Vela, however, there are a number of
scenarios where the weights are not buffered in SRAM, and the
FullyConnected case is one of them.
---
 .../backend/contrib/ethosu/tir/scheduler.py       | 10 +++++++++-
 .../python/contrib/test_ethosu/test_scheduler.py  | 16 ++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
index bcabe2b7c2fa..cee8f563ff7a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
@@ -132,6 +132,12 @@ def copy_constants():
     def _planner(cached_func, const_dict, sch):
         planned = set()  # type: ignore
 
+        def _is_matmul(tensor):
+            if tensor.name not in ["ethosu_conv2d"]:
+                return False
+            a, b = tensor.op.input_tensors[0:2]
+            return a.shape[1:3] == [1, 1] and b.shape[1:3] == [1, 1]
+
         def _visit(tensor, reader, lut):
             if tensor not in planned:
                 planned.add(tensor)
@@ -140,7 +146,9 @@ def _visit(tensor, reader, lut):
                     # ambiguity when encountering a scalar.
                     is_same = [var.same_as(tensor) for var in cached_func.inputs]
                     index = is_same.index(True)
-                    if index in const_dict:
+                    # Along with constants, also skip for FullyConnected to correspond
+                    # with Vela behavior.
+                    if index in const_dict and not _is_matmul(reader):
                         sch.cache_read(tensor, "global", [reader])
 
                 elif isinstance(tensor.op, tvm.te.ComputeOp):
diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py
index fd1e1afa60d9..695aed0d1919 100644
--- a/tests/python/contrib/test_ethosu/test_scheduler.py
+++ b/tests/python/contrib/test_ethosu/test_scheduler.py
@@ -217,5 +217,21 @@ def test_schedule_diamond_graph():
     tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
 
+def test_copy_constants_fully_connected_weights():
+    """Check that MatMul-like conv2d ops do not copy weights to SRAM."""
+    ifm = relay.var("IFM", shape=(1, 1, 1, 32), dtype="int8")
+    conv = make_ethosu_conv2d(ifm, 32, 8, (1, 1), (0, 0), (1, 1), (1, 1))
+    func = relay.Function(relay.analysis.free_vars(conv), conv)
+    func = run_opt_pass(func, relay.transform.InferType())
+
+    func, const_dict = extract_constants(func)
+    cached_func = lower_to_te(func)
+
+    sch = te.create_schedule([cached_func.outputs[0].op])
+    planner = copy_constants()
+    planner(cached_func, const_dict, sch)
+    assert True not in [".global" in s.op.name for s in sch.stages]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
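
Note: the standalone sketch below (not part of the patch) illustrates what the
_is_matmul heuristic keys on. A FullyConnected op reaches the scheduler
legalized as a 1x1 ethosu_conv2d over a (1, 1, 1, C) IFM, so both input
tensors carry unit spatial dimensions. The placeholder names and shapes here
are illustrative assumptions, not taken from the TVM test infrastructure.

    from tvm import te

    # IFM and weights for a FullyConnected lowered as a 1x1 conv2d: both
    # have height == width == 1 (shapes are illustrative assumptions).
    ifm = te.placeholder((1, 1, 1, 32), dtype="int8", name="ifm")
    weights = te.placeholder((8, 1, 1, 32), dtype="int8", name="weights")

    # Stand-in compute; only the op name and the input tensor shapes
    # matter to the heuristic, not the actual arithmetic.
    conv = te.compute(
        (1, 1, 1, 8),
        lambda n, h, w, o: ifm[n, h, w, 0] * weights[o, 0, 0, 0],
        name="ethosu_conv2d",
    )

    def _is_matmul(tensor):
        if tensor.name not in ["ethosu_conv2d"]:
            return False
        a, b = tensor.op.input_tensors[0:2]
        return a.shape[1:3] == [1, 1] and b.shape[1:3] == [1, 1]

    print(_is_matmul(conv))  # True: unit spatial dims on both inputs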