Skip to content

Commit

Permalink
Merge pull request #19 from cgerum/uma_check_compiler_attrs
Browse files Browse the repository at this point in the history
Use better function name for te_lowering and annotate current target …
  • Loading branch information
MichaelJKlaiber authored Aug 5, 2022
2 parents 41d9a84 + c250e54 commit 1878882
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/tvm/relay/backend/contrib/uma/api/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,20 @@ def _get_tensors(te_cached_func):

return args + outputs

f = tvm._ffi.get_global_func("relay.backend.LowerToTE")
te_cached_func = f(relay_prim_func)
lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
te_cached_func = lower_to_te(relay_prim_func)
x = _get_tensors(te_cached_func)
tir_prim_func = te.create_prim_func(x)
tir_prim_func = tir_prim_func.with_attr(
"global_symbol", relay_prim_func.attrs["global_symbol"]
)
# TODO: The target should probably come from somewhere else instead of being created here.
tir_prim_func = tir_prim_func.with_attr("target", tvm.target.Target(self.target_name))

compiler_attr = relay_prim_func.attrs["Compiler"]
target = tvm.target.Target.current()
if target.kind.name != compiler_attr:
target = tvm.target.Target(compiler_attr)

tir_prim_func = tir_prim_func.with_attr("target", target)
tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs)
return tir_prim_func

Expand Down

0 comments on commit 1878882

Please sign in to comment.