Skip to content

Commit

Permalink
Increase stack limit for failing tflite tests. Skip TF tests which re…
Browse files Browse the repository at this point in the history
…quire TF 1.x
  • Loading branch information
Trevor Morris authored and trevor-m committed Jun 18, 2020
1 parent 7f7b8e8 commit 113094c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
1 change: 1 addition & 0 deletions docker/bash.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ ${DOCKER_BINARY} run --rm --pid=host\
-v ${WORKSPACE}:/workspace \
-v ${SCRIPT_DIR}:/docker \
-w /workspace \
--ulimit stack=16777216:16777216 \
-e "CI_BUILD_HOME=/workspace" \
-e "CI_BUILD_USER=$(id -u -n)" \
-e "CI_BUILD_UID=$(id -u)" \
Expand Down
15 changes: 15 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,9 @@ def test_forward_squeeze():
# TensorArray
# -----------
def test_tensor_array_write_read():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
pytest.skip("Needs fixing for tflite >= 1.15.0")

def run(dtype_str, infer_shape, element_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
Expand All @@ -895,6 +898,9 @@ def run(dtype_str, infer_shape, element_shape):


def test_tensor_array_scatter():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
pytest.skip("Needs fixing for tflite >= 1.15.0")

def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
Expand All @@ -921,6 +927,9 @@ def run(dtype_str, infer_shape):


def test_tensor_array_gather():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
pytest.skip("Needs fixing for tflite >= 1.15.0")

def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
Expand All @@ -937,6 +946,9 @@ def run(dtype_str, infer_shape):


def test_tensor_array_split():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
pytest.skip("Needs fixing for tflite >= 1.15.0")

def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
Expand All @@ -959,6 +971,9 @@ def run(dtype_str, infer_shape):


def test_tensor_array_concat():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
pytest.skip("Needs fixing for tflite >= 1.15.0")

def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
Expand Down
2 changes: 0 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,7 +2286,6 @@ def test_forward_qnn_mobilenet_v1_net():
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

@pytest.mark.skip("neo-ai/tvm: disabled due to segfault in CI")
def test_forward_qnn_mobilenet_v2_net():
"""Test the Quantized TFLite Mobilenet V2 model."""
# MobilenetV2
Expand Down Expand Up @@ -2439,7 +2438,6 @@ def test_forward_coco_ssd_mobilenet_v1():
# MediaPipe
# -------------

@pytest.mark.skip("neo-ai/tvm: disabled due to error in CI")
def test_forward_mediapipe_hand_landmark():
"""Test MediaPipe 2D hand landmark TF Lite model."""
# MediaPipe 2D hand landmark TF
Expand Down

0 comments on commit 113094c

Please sign in to comment.