-
Notifications
You must be signed in to change notification settings - Fork 8
/
get-model.sh
executable file
·45 lines (34 loc) · 1017 Bytes
/
get-model.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env bash
# Exit on any error, and make a pipeline fail when any of its stages fails
# (not just the last one), so a broken stage cannot be silently ignored.
set -eo pipefail
#######################################
# Print usage information and exit with status 1.
# Arguments: none (uses $0 for the script name)
# Outputs:   usage text on stderr, since it is shown on an error path
# Returns:   never returns; always exits 1
#######################################
usage() {
cat >&2 <<_EOT_
Usage:
$0 model-name
Options:
model-name Specify a model name to download. (e.g. bodypix/resnet50/float/model-stride16)
Refer to https://storage.googleapis.com/tfjs-models for the available models.
_EOT_
exit 1
}
# Check args: a non-empty model name is required.
if [[ $# -lt 1 || -z "$1" ]]; then
  usage
fi

# Define constants & variables
readonly BASE_URL=https://storage.googleapis.com/tfjs-models/savedmodel
readonly MODEL_NAME=$1
# Flatten the model path into one directory name, e.g.
# bodypix/resnet50/float/model-stride16 -> bodypix_resnet50_float_model-stride16
readonly DIR_NAME=${MODEL_NAME//\//_}
# Weight shards are fetched from the model name's parent path
# (everything before the last "/").
readonly WEIGHTS_URL=${BASE_URL}/${MODEL_NAME%/*}

# Verify jq is installed (command -v is the portable replacement for which).
JQ=$(command -v jq || :)
if [[ -z "${JQ}" ]]; then
  echo 'Please install "jq".' >&2
  exit 1
fi

# Fetch model.json and the weight shards it references.
# mkdir -p lets a re-run reuse an existing directory so that `wget -c` can
# resume partial downloads instead of the script aborting here under set -e.
mkdir -p -- "${DIR_NAME}"
pushd "${DIR_NAME}" >/dev/null
wget -c -nv "${BASE_URL}/${MODEL_NAME}.json" -O model.json

# List every shard path from the manifest, one per line and unquoted.
# Capturing via an assignment (rather than piping) lets set -e abort the
# script if jq fails, e.g. on a malformed or missing weightsManifest.
weight_paths=$("${JQ}" -r '.weightsManifest[].paths[]' model.json)
while IFS= read -r weight_path; do
  # An empty manifest yields a single empty line; skip it.
  [[ -n "${weight_path}" ]] || continue
  wget -c "${WEIGHTS_URL}/${weight_path}"
done <<<"${weight_paths}"
popd >/dev/null
echo "Successfully downloaded to: ${DIR_NAME}"

# Convert to the tf_frozen_model format
tfjs_graph_converter "${DIR_NAME}" "${DIR_NAME}.pb"