Skip to content

Commit

Permalink
[RELAY] Move frontend utils (apache#5345)
Browse files Browse the repository at this point in the history
* [RELAY] Move frontend utils

The util file currently under frontend is used from
outside of frontend (in qnn/op/legalizations). This suggests
that the file should be pushed up to a higher level.

The benefit from this change is that importing qnn no longer
also imports all the frontends.

* Inline get_scalar_from_constant

Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4

* Remove util.py from Relay

Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96

* Shorten functions

Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2

* Line length

Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f
  • Loading branch information
mbaret authored and dhruvaray committed Apr 28, 2020
1 parent ae9a581 commit 4c06e2e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 35 deletions.
12 changes: 11 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from .. import op as _op
from .. import qnn as _qnn
from ... import nd as _nd
from .util import get_scalar_from_constant
from .common import ExprTable
from .common import infer_shape as _infer_shape

Expand Down Expand Up @@ -2341,6 +2340,17 @@ def get_expr(self, input_tensor_idx):
def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))


def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return np.asscalar(value)


def build_str_map(obj):
"""Build string map of TFLite enum int value
Expand Down
33 changes: 0 additions & 33 deletions python/tvm/relay/frontend/util.py

This file was deleted.

11 changes: 10 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import tvm
from tvm import relay
import numpy as np
from .. import op as reg
from ...frontend.util import get_scalar_from_constant

#################################################
# Register the functions for different operators.
Expand Down Expand Up @@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types):
# Helper functions.
###################

def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, relay.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return np.asscalar(value)

# Helper function for lowering in the abscence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
""" Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
Expand Down

0 comments on commit 4c06e2e

Please sign in to comment.