From 58521d040e00025e2872f306901d0db63adeb7c3 Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Thu, 22 Aug 2024 23:06:57 +1000 Subject: [PATCH] added CLI options to tflite converter to upgrade older weights files --- impsy/tests/test_data.py | 4 ++-- impsy/tflite_converter.py | 26 ++++++++++++++++++++------ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/impsy/tests/test_data.py b/impsy/tests/test_data.py index 73ca1f8..6f5a5da 100644 --- a/impsy/tests/test_data.py +++ b/impsy/tests/test_data.py @@ -120,10 +120,10 @@ def test_weights_to_model_file(trained_model, dimension, tmp_path_factory, mdrnn weights_file = trained_model["weights_file"] print(f"Weights file: {weights_file}") test_dir = tmp_path_factory.mktemp("model_file") - model_file_name = tflite_converter.weights_file_to_model_file(weights_file, mdrnn_size, dimension, test_dir) + model_file_name = tflite_converter.weights_file_to_model_file(weights_file, mdrnn_size, dimension) print(f"File returned: {model_file_name}") assert os.path.exists(model_file_name) - os.remove(model_file_name) + # os.remove(model_file_name) def test_model_file_to_tflite(trained_model): diff --git a/impsy/tflite_converter.py b/impsy/tflite_converter.py index 1cb7ef9..8452e05 100644 --- a/impsy/tflite_converter.py +++ b/impsy/tflite_converter.py @@ -70,7 +70,7 @@ def config_to_tflite(config_path): model_to_tflite(net.model, model_path) -def weights_file_to_model_file(weights_file, model_size, dimension, location): +def weights_file_to_model_file(weights_file, model_size, dimension): """Constructs a model from a given weights file and saves as a .keras inference model.""" import impsy.mdrnn as mdrnn @@ -84,12 +84,26 @@ def weights_file_to_model_file(weights_file, model_size, dimension, location): ) inference_model.load_model(model_file=weights_file) model_name = inference_model.model_name() - keras_filename = Path(location) / f"{model_name}.keras" - inference_model.model.save(keras_filename) - return keras_filename + keras_file_path = Path(weights_file).with_suffix(".keras") + inference_model.model.save(keras_file_path) + return keras_file_path @click.command(name="convert-tflite") -def convert_tflite(): +@click.option('--model', '-m', help='Path to a .keras model or .h5 weights') +@click.option('--dimension', '-d', type=int, help='Dimension (only needed for h5 files)') +@click.option('--size', '-s', help="Size, one of xs, s, m, l, (only needed for h5 files)") +def convert_tflite(model, dimension, size): """Convert existing IMPSY model to tflite format.""" - config_to_tflite("config.toml") + if model is None: + config_to_tflite("config.toml") + elif Path(model).suffix == ".keras": + # it's a keras file + model_file_to_tflite(model) + elif Path(model).suffix == ".h5": + # it's an h5 file + if dimension is not None and size is not None: + model_file = weights_file_to_model_file(model, size, dimension) + model_file_to_tflite(model_file) + else: + click.secho("You need to specify a dimension and size to convert an h5 file.")