Merge pull request #117 from joeyhng/master
Model Export for MediaPipe inference
XericZephyr authored Oct 25, 2019
2 parents c643cb6 + e9c07c1 commit e6f6bf6
Showing 2 changed files with 71 additions and 0 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -27,6 +27,7 @@ model.
* [Evaluation and Inference](#evaluation-and-inference)
* [Create Your Own Dataset Files](#create-your-own-dataset-files)
* [Training without this Starter Code](#training-without-this-starter-code)
* [Export Your Model for MediaPipe Inference](#export-your-model-for-mediapipe-inference)
* [More Documents](#more-documents)
* [About This Project](#about-this-project)

@@ -321,6 +322,16 @@ and the following for the inference code:
num examples processed: 8192 elapsed seconds: 14.85
```

## Export Your Model for MediaPipe Inference
To run inference with your model in the [MediaPipe inference
demo](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m#steps-to-run-the-youtube-8m-inference-graph-with-the-yt8m-dataset), you first need to export your trained checkpoint as a TensorFlow SavedModel.

Example command:
```sh
python export_model_mediapipe.py --checkpoint_file ~/yt8m/models/frame/sample_model/inference_model/segment_inference_model --output_dir /tmp/mediapipe/saved_model/
```
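After exporting, you can sanity-check the SavedModel outside of MediaPipe by loading it back and running a dummy batch. The sketch below is an illustration, not part of the starter code: it uses a tiny stand-in graph (mean-pooled features plus a bias) instead of the restored YT8M checkpoint, and it uses the TF1-style API via `tf.compat.v1` so it also runs under TensorFlow 2.x. The real exported model is consumed the same way through its `serving_default` signature.

```python
# Illustrative save/load round trip (stand-in graph, not the YT8M model).
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph APIs, also under TF 2.x

tf.disable_eager_execution()

export_dir = tempfile.mkdtemp()

# --- Export: mirrors the signature used by export_model_mediapipe.py. ---
with tf.Session(graph=tf.Graph()) as sess:
  features = tf.placeholder(tf.float32, shape=(None, None, 1152))
  num_frames = tf.placeholder(tf.int32, shape=(None, 1))
  # Stand-in "model": mean-pool the features and add a trainable bias.
  bias = tf.Variable(tf.zeros([1]), name='bias')
  predictions = tf.expand_dims(
      tf.reduce_mean(features, axis=[1, 2]), axis=1) + bias
  sess.run(tf.global_variables_initializer())
  tf.saved_model.simple_save(
      sess, export_dir,
      inputs={'rgb_and_audio': features, 'num_frames': num_frames},
      outputs={'predictions': predictions})

# --- Load: how a consumer of the exported SavedModel runs inference. ---
with tf.Session(graph=tf.Graph()) as sess:
  # 'serve' is the tag simple_save attaches to the meta graph.
  meta_graph = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = meta_graph.signature_def['serving_default']
  out = sess.run(
      sig.outputs['predictions'].name,
      feed_dict={
          sig.inputs['rgb_and_audio'].name:
              np.zeros((3, 7, 1152), dtype=np.float32),
          sig.inputs['num_frames'].name:
              np.full((3, 1), 7, dtype=np.int32),
      })
  print('predictions shape:', out.shape)  # (3, 1)
```

Resolving tensors through `signature_def['serving_default']` rather than hard-coding tensor names keeps the loading code independent of how the graph was built.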


## Create Your Own Dataset Files

You can create your dataset files from your own videos. Our
60 changes: 60 additions & 0 deletions export_model_mediapipe.py
@@ -0,0 +1,60 @@
# Lint as: python3
"""Exports a YT8M model checkpoint as a SavedModel for MediaPipe inference."""
import numpy as np
import tensorflow as tf
from tensorflow import app
from tensorflow import flags

FLAGS = flags.FLAGS


def main(unused_argv):
  # Load the checkpoint's meta graph once just to look up the names of the
  # input tensors that need to be replaced.
  tf.reset_default_graph()
  meta_graph_location = FLAGS.checkpoint_file + ".meta"
  tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

  input_tensor_name = tf.get_collection("input_batch_raw")[0].name
  num_frames_tensor_name = tf.get_collection("num_frames")[0].name

  # Create the output graph, remapping the model inputs to well-shaped
  # placeholders.
  tf.reset_default_graph()

  input_feature_placeholder = tf.placeholder(
      tf.float32, shape=(None, None, 1152))
  num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))

  saver = tf.train.import_meta_graph(
      meta_graph_location,
      input_map={
          input_tensor_name: input_feature_placeholder,
          num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
      },
      clear_devices=True)
  predictions_tensor = tf.get_collection("predictions")[0]

  with tf.Session() as sess:
    print("restoring variables from " + FLAGS.checkpoint_file)
    saver.restore(sess, FLAGS.checkpoint_file)
    tf.saved_model.simple_save(
        sess,
        FLAGS.output_dir,
        inputs={'rgb_and_audio': input_feature_placeholder,
                'num_frames': num_frames_placeholder},
        outputs={'predictions': predictions_tensor})

    # Try running inference on a dummy batch (one num_frames entry per
    # example in the batch).
    predictions = sess.run(
        [predictions_tensor],
        feed_dict={
            input_feature_placeholder:
                np.zeros((3, 7, 1152), dtype=np.float32),
            num_frames_placeholder:
                np.array([[7], [7], [7]], dtype=np.int32)})
    print('Test inference:', predictions)

  print('Model saved to', FLAGS.output_dir)


if __name__ == '__main__':
  flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
  flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
  app.run(main)
