forked from Qihoo360/tensornet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigure.sh
executable file
·113 lines (83 loc) · 2.57 KB
/
configure.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env bash
# Workspace configuration script: links the OpenMPI and TensorFlow headers and
# libraries into thirdparty/. Requires bash (BASH_SOURCE, pushd, [[ ]] are used).

# Absolute directory that contains this script; abort if it cannot be resolved.
WORKSPACE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || exit 1
readonly WORKSPACE_DIR
# Absolute path of this script, shown in the usage text.
readonly THIS_FILE="${WORKSPACE_DIR}/$(basename "${BASH_SOURCE[0]}")"
# Work from the script's directory so the relative thirdparty/ paths resolve;
# quoted so a checkout path containing spaces still works.
pushd "${WORKSPACE_DIR}" > /dev/null || exit 1
# Python interpreter found on PATH (empty when none is found).
PYTHON_PATH="$(command -v python)"
readonly PYTHON_PATH
# global parameter: OpenMPI installation prefix, filled in from --openmpi_path
OPENMPI_PATH=""
# Prints the usage text for this script to stdout.
# Globals:   THIS_FILE (read)
# Returns:   0 always
_simple_help() {
  local usage_text
  # The text starts and ends with a newline, so the output is framed by blank
  # lines; %b expands backslash escapes the same way `echo -e` does.
  usage_text="
${THIS_FILE} arguments
arguments:
--openmpi_path the path of your openmpi installed
--help help info
"
  printf '%b\n' "${usage_text}"
  return 0
}
# Parses the command-line arguments and checks that an OpenMPI path was given.
# Globals:   OPENMPI_PATH (written)
# Arguments: the script's command-line arguments
# Outputs:   echoes the resulting OPENMPI_PATH; prints usage on help or error
# Returns:   0 on success. Exits 0 for --help/-h; exits 1 for an unknown
#            argument, a missing option value, or when no path was supplied.
simple_eval_param() {
  local key
  while [[ $# -gt 0 ]]; do
    key="$1"
    case "${key}" in
      --openmpi_path)
        # Report a missing value explicitly instead of relying on a failing shift.
        if [[ $# -lt 2 ]]; then
          echo "option --openmpi_path requires a value" >&2
          _simple_help
          exit 1
        fi
        OPENMPI_PATH="$2"
        shift
        ;;
      --help | -h)
        # Asking for help is not an error.
        _simple_help
        exit 0
        ;;
      *)
        echo "unknown argument: ${key}" >&2
        _simple_help
        exit 1
        ;;
    esac
    shift
  done
  echo "${OPENMPI_PATH}"
  if [[ -z "${OPENMPI_PATH}" ]]; then
    echo "please specify where openmpi installed"
    _simple_help
    exit 1
  fi
  return 0
}
# Verifies that the TensorFlow importable by `python` is a 2.2.x or 2.3.x release.
# Outputs:   progress and result messages on stdout; import failure on stderr
# Returns:   0 when the version is accepted; exits 1 when TensorFlow cannot be
#            queried or its major.minor version is neither 2.2 nor 2.3.
check_tf_version() {
  echo "checking tensorflow version installed..."
  # Declaration is kept separate from assignment so a failing python call is
  # detected instead of being masked by the exit status of `local`.
  local tf_version
  if ! tf_version=$(python -c "import tensorflow as tf; print(tf.version.VERSION)"); then
    echo "failed to query the tensorflow version, is tensorflow installed?" >&2
    exit 1
  fi
  # Split on '.' and keep only the first two components (major.minor).
  local tf_major tf_minor tf_rest
  IFS=. read -r tf_major tf_minor tf_rest <<< "${tf_version}"
  local tf_major_version="${tf_major}.${tf_minor}"
  if [[ "${tf_major_version}" != "2.2" && "${tf_major_version}" != "2.3" ]]; then
    echo "tensorflow version is ${tf_version}, please use 2.2.0 ~ 2.3.0 instead"
    exit 1
  fi
  echo "tensorflow version installed is ${tf_version}"
}
# Replaces thirdparty/openmpi/include and thirdparty/openmpi/lib with symlinks
# to the include/ and lib/ directories under OPENMPI_PATH.
# Globals:   OPENMPI_PATH (read)
# Outputs:   the include and lib paths being used, on stdout
# Returns:   0 on success, 1 if the directory or a symlink cannot be created
link_mpi_thirdparty() {
  echo "using openmpi include path:$OPENMPI_PATH/include"
  echo "using openmpi lib path:$OPENMPI_PATH/lib"
  # Make sure the link directory exists so ln does not fail on a fresh tree.
  mkdir -p thirdparty/openmpi || return 1
  # Remove stale links first; no trailing slash, so link targets are untouched.
  rm -rf thirdparty/openmpi/include thirdparty/openmpi/lib
  # Quoted so an OpenMPI prefix containing spaces is passed as one argument.
  ln -s "${OPENMPI_PATH}/include" thirdparty/openmpi/ || return 1
  ln -s "${OPENMPI_PATH}/lib" thirdparty/openmpi/ || return 1
}
# Symlinks the installed TensorFlow libraries and include directory into
# thirdparty/tensorflow/, using the paths reported by tf.sysconfig.
# Outputs:   the TensorFlow lib path being used, on stdout
# Returns:   1 without touching thirdparty/ when the TensorFlow paths cannot be
#            queried; otherwise the status of the final ln
link_tf_thirdparty() {
  local tf_include_path
  local tf_lib_path
  # Stop if python fails: an empty lib path would turn the glob below into
  # /lib* and link system libraries into the tree.
  tf_include_path=$(python -c "import tensorflow as tf;print(tf.sysconfig.get_include())") || return 1
  tf_lib_path=$(python -c "import tensorflow as tf;print(tf.sysconfig.get_lib())") || return 1
  if [[ -z "${tf_include_path}" || -z "${tf_lib_path}" ]]; then
    return 1
  fi
  echo "using tensorflow lib path:${tf_lib_path}"
  # Create the directory before clearing it so the first run does not error.
  mkdir -p thirdparty/tensorflow/lib/
  rm -f thirdparty/tensorflow/lib/*
  # The path is quoted and the glob left outside the quotes so lib* still expands.
  ln -s "${tf_lib_path}"/lib* thirdparty/tensorflow/lib/
  # Expose the python extension module under a lib-prefixed name.
  ln -sf "${tf_lib_path}/python/_pywrap_tensorflow_internal.so" thirdparty/tensorflow/lib/lib_pywrap_tensorflow_internal.so
  ln -sf "${tf_include_path}" thirdparty/tensorflow/
}
# Entry point: parses the arguments, checks the TensorFlow version, then links
# the OpenMPI and TensorFlow directories into thirdparty/.
# Globals:   PYTHON_PATH (read)
# Arguments: the script's command-line arguments
# Outputs:   progress messages on stdout
main() {
  echo "using python:${PYTHON_PATH}"
  # "$@" keeps every argument intact, even when a path contains spaces.
  simple_eval_param "$@"
  check_tf_version
  link_mpi_thirdparty
  link_tf_thirdparty
  echo "configure done"
}
# Run the configuration with the script's own arguments; "$@" passes each
# argument through as a single word even when it contains spaces.
main "$@"
# Return to the directory that was current before the pushd at startup.
popd > /dev/null
exit 0