diff --git a/python/tvm/relay/backend/contrib/uma/api/lower.py b/python/tvm/relay/backend/contrib/uma/api/lower.py index dc85e4a6bd904..34630949a1512 100644 --- a/python/tvm/relay/backend/contrib/uma/api/lower.py +++ b/python/tvm/relay/backend/contrib/uma/api/lower.py @@ -82,15 +82,20 @@ def _get_tensors(te_cached_func): return args + outputs - f = tvm._ffi.get_global_func("relay.backend.LowerToTE") - te_cached_func = f(relay_prim_func) + lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE") + te_cached_func = lower_to_te(relay_prim_func) x = _get_tensors(te_cached_func) tir_prim_func = te.create_prim_func(x) tir_prim_func = tir_prim_func.with_attr( "global_symbol", relay_prim_func.attrs["global_symbol"] ) - # TODO: The target should probably come from somewhere else instead of being created here. - tir_prim_func = tir_prim_func.with_attr("target", tvm.target.Target(self.target_name)) + + compiler_attr = relay_prim_func.attrs["Compiler"] + target = tvm.target.Target.current() + if target.kind.name != compiler_attr: + target = tvm.target.Target(compiler_attr) + + tir_prim_func = tir_prim_func.with_attr("target", target) tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs) return tir_prim_func