Skip to content

Commit

Permalink
Enable quantization for tflite rewriter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 304026202
  • Loading branch information
tfx-copybara authored and tensorflow-extended-team committed Mar 31, 2020
1 parent a321c4e commit f581530
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* Changed CLI behavior to create new versions of pipelines instead of
delete and create new ones when pipelines are updated for KFP. (Requires
kfp >= 0.3.0)
* Added ability to enable quantization in tflite rewriter.

### Deprecations

Expand Down
19 changes: 14 additions & 5 deletions tfx/components/trainer/rewriting/tflite_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@


def _create_tflite_converter(
saved_model_path: Text,
enable_experimental_new_converter: bool) -> tf.lite.TFLiteConverter:
saved_model_path: Text, enable_experimental_new_converter: bool,
enable_quantization: bool) -> tf.lite.TFLiteConverter:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
converter.experimental_new_converter = enable_experimental_new_converter
if enable_quantization:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
return converter


Expand All @@ -58,7 +60,8 @@ def __init__(self,
filename: Text = 'tflite',
enable_experimental_new_converter: bool = False,
copy_assets: bool = True,
copy_assets_extra: bool = True):
copy_assets_extra: bool = True,
enable_quantization: bool = False):
"""Create an instance of the TFLiteRewriter.
Args:
Expand All @@ -69,13 +72,16 @@ def __init__(self,
model directory.
copy_assets_extra: Boolean whether to copy the assets.extra directory to
the rewritten model directory.
enable_quantization: Boolean whether to enable default TFLite
quantization.
"""
# TODO(dzats): Add additional options for the TFLiteRewriter.
# TODO(b/152636072): Add support for representative_dataset.
self._name = name
self._filename = six.ensure_text(filename)
self._enable_experimental_new_converter = enable_experimental_new_converter
self._copy_assets = copy_assets
self._copy_assets_extra = copy_assets_extra
self._enable_quantization = enable_quantization

@property
def name(self) -> Text:
Expand Down Expand Up @@ -128,7 +134,10 @@ def _rewrite(self, original_model: rewriter.ModelDescription,
six.ensure_text(original_model.path), tmp_model_dir)

converter = _create_tflite_converter(
tmp_model_dir, self._enable_experimental_new_converter)
saved_model_path=tmp_model_dir,
enable_experimental_new_converter=self
._enable_experimental_new_converter,
enable_quantization=self._enable_quantization)
tflite_model = converter.convert()

output_path = os.path.join(
Expand Down
25 changes: 19 additions & 6 deletions tfx/components/trainer/rewriting/tflite_rewriter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ class TFLiteRewriterTest(tf.test.TestCase):

class ConverterMock(object):

def __init__(self):
self._convert_called = False

def convert(self):
self._convert_called = True
return 'model'

@mock.patch('tfx.components.trainer.rewriting.'
Expand All @@ -62,8 +58,16 @@ def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
dst_model_path)

tfrw = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
tfrw = tflite_rewriter.TFLiteRewriter(
name='myrw',
filename='fname',
enable_experimental_new_converter=True)
tfrw.perform_rewrite(src_model, dst_model)

converter.assert_called_once_with(
saved_model_path=mock.ANY,
enable_experimental_new_converter=True,
enable_quantization=False)
expected_model = os.path.join(dst_model_path, 'fname')
self.assertTrue(tf.io.gfile.exists(expected_model))
with tf.io.gfile.GFile(expected_model, 'rb') as f:
Expand Down Expand Up @@ -100,8 +104,17 @@ def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
dst_model_path)

tfrw = tflite_rewriter.TFLiteRewriter('myrw', 'fname', True)
tfrw = tflite_rewriter.TFLiteRewriter(
name='myrw',
filename='fname',
enable_experimental_new_converter=True,
enable_quantization=True)
tfrw.perform_rewrite(src_model, dst_model)

converter.assert_called_once_with(
saved_model_path=mock.ANY,
enable_experimental_new_converter=True,
enable_quantization=True)
expected_model = os.path.join(dst_model_path, 'fname')
self.assertTrue(tf.io.gfile.exists(expected_model))
with tf.io.gfile.GFile(expected_model, 'rb') as f:
Expand Down

0 comments on commit f581530

Please sign in to comment.