Skip to content

Commit

Permalink
added CLI options to tflite converter to upgrade older weights files
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Aug 22, 2024
1 parent a807b1b commit 58521d0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions impsy/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def test_weights_to_model_file(trained_model, dimension, tmp_path_factory, mdrnn
weights_file = trained_model["weights_file"]
print(f"Weights file: {weights_file}")
test_dir = tmp_path_factory.mktemp("model_file")
model_file_name = tflite_converter.weights_file_to_model_file(weights_file, mdrnn_size, dimension, test_dir)
model_file_name = tflite_converter.weights_file_to_model_file(weights_file, mdrnn_size, dimension)
print(f"File returned: {model_file_name}")
assert os.path.exists(model_file_name)
os.remove(model_file_name)
# os.remove(model_file_name)


def test_model_file_to_tflite(trained_model):
Expand Down
26 changes: 20 additions & 6 deletions impsy/tflite_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def config_to_tflite(config_path):
model_to_tflite(net.model, model_path)


def weights_file_to_model_file(weights_file, model_size, dimension, location):
def weights_file_to_model_file(weights_file, model_size, dimension):
"""Constructs a model from a given weights file and saves as a .keras inference model."""
import impsy.mdrnn as mdrnn

Expand All @@ -84,12 +84,26 @@ def weights_file_to_model_file(weights_file, model_size, dimension, location):
)
inference_model.load_model(model_file=weights_file)
model_name = inference_model.model_name()
keras_filename = Path(location) / f"{model_name}.keras"
inference_model.model.save(keras_filename)
return keras_filename
keras_file_path = Path(weights_file).with_suffix(".keras")
inference_model.model.save(keras_file_path)
return keras_file_path


@click.command(name="convert-tflite")
def convert_tflite():
@click.option('--model', '-m', help='Path to a .keras model or .h5 weights')
@click.option('--dimension', '-d', type=int, help='Dimension (only needed for h5 files)')
@click.option('--size', '-s', help="Size, one of xs, s, m, l, (only needed for h5 files)")
def convert_tflite(model, dimension, size):
"""Convert existing IMPSY model to tflite format."""
config_to_tflite("config.toml")
if model is None:
config_to_tflite("config.toml")
elif Path(model).suffix == ".keras":
# it's a keras file
model_file_to_tflite(model)
elif Path(model).suffix == ".h5":
# it's an h5 file
if dimension is not None and size is not None:
model_file = weights_file_to_model_file(model, size, dimension)
model_file_to_tflite(model_file)
else:
click.secho("You need to specify a dimension and size to convert an h5 file.")

0 comments on commit 58521d0

Please sign in to comment.