forked from kubeflow/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use conditionals and add test for code search (kubeflow#291)
* Fix model export, loss function, and add some manual tests. Fix Model export to support computing code embeddings: Fix kubeflow#260 * The previous exported model was always using the embeddings trained for the search query. * But we need to be able to compute embedding vectors for both the query and code. * To support this we add a new input feature "embed_code" and conditional ops. The exported model uses the value of the embed_code feature to determine whether to treat the inputs as a query string or code and computes the embeddings appropriately. * Originally based on kubeflow#233 by @activatedgeek Loss function improvements * See kubeflow#259 for a long discussion about different loss functions. * @activatedgeek was experimenting with different loss functions in kubeflow#233 and this pulls in some of those changes. Add manual tests * Related to kubeflow#258 * We add a smoke test for T2T steps so we can catch bugs in the code. * We also add a smoke test for serving the model with TFServing. * We add a sanity check to ensure we get different values for the same input based on which embeddings we are computing. Change Problem/Model name * Register the problem github_function_docstring with a different name to distinguish it from the version inside the Tensor2Tensor library. * * Skip the test when running under prow because its a manual test. * Fix some lint errors. * * Fix lint and skip tests. * Fix lint. * * Fix lint * Revert loss function changes; we can do that in a follow on PR. * * Run generate_data as part of the test rather than reusing a cached vocab and processed input file. * Modify SimilarityTransformer so we can overwrite the number of shards used easily to facilitate testing. * Comment out py-test for now.
- Loading branch information
Showing
20 changed files
with
127,338 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Developer guide for the code search example | ||
|
||
This doc is intended for folks looking to contribute to the example. | ||
|
||
## Testing | ||
|
||
We currently have tests that can be run manually to test the code. | ||
We hope to get these integrated into our CI system soon. | ||
|
||
### T2T Test | ||
|
||
The test code_search/src/code_search/t2t/similarity_transformer_test.py | ||
can be used to test | ||
|
||
* Training | ||
* Evaluation | ||
* Model Export | ||
|
||
The test can be run as follows | ||
|
||
``` | ||
cd code_search/src | ||
python3 -m code_searcch.t2t.similarity_transformer_export_test | ||
``` | ||
The test just runs the relevant T2T steps and verifies they succeeds. No additional | ||
checks are executed. | ||
|
||
|
||
### TF Serving test | ||
|
||
code_search/src/code_search/nmslib/cli/embed_query_test.py | ||
|
||
|
||
Can be used to test generating predictions using TFServing. | ||
|
||
The test assumes the TFServing is running in a docker container | ||
|
||
You can start TFServing as follows | ||
|
||
``` | ||
./code_search/nmslib/cli/start_test_server.sh | ||
``` | ||
|
||
You can then run the test | ||
|
||
``` | ||
export PYTHONPATH=${EXAMPLES_REPO/code_search/src:${PYTHONPATH} | ||
python3 -m embed_query_test | ||
``` | ||
|
||
The test verifies that the code can successfully generate embeddings using TFServing. | ||
|
||
The test verifies that different embeddings are computed for the query and the code. | ||
|
||
**start_test_server.sh** relies on a model stored in **code_search/src/code_search/t2t/** | ||
A new model can be produced by running **similarity_transformer_export_test**. The unittest | ||
will export the model to a temporary directory. You can then copy that model to the test_data | ||
directory. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
code_search/src/code_search/dataflow/transforms/function_embeddings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
code_search/src/code_search/nmslib/cli/embed_query_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# coding=utf-8 | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Test embedding query using TFServing. | ||
This is a manual/E2E test that assumes TFServing is running externally | ||
(e.g. Docker container or K8s pod). | ||
The script start_test_server.sh can be used to start a Docker container | ||
when running locally. | ||
To run TFServing we need a model. start_test_server.sh will use a model | ||
in ../../t2t/test_data/model | ||
code_search must be a top level Python package. | ||
requires host machine has tensorflow_model_server executable available | ||
""" | ||
|
||
# TODO(jlewi): Starting the test seems very slow. I wonder if this is because | ||
# tensor2tensor is loading a bunch of models and if maybe we can skip that. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import datetime | ||
import logging | ||
import os | ||
import unittest | ||
import tensorflow as tf | ||
|
||
import numpy as np | ||
|
||
from code_search.nmslib.cli import start_search_server | ||
|
||
start = datetime.datetime.now() | ||
|
||
FLAGS = tf.flags.FLAGS | ||
|
||
PROBLEM_NAME = "kf_github_function_docstring" | ||
|
||
class TestEmbedQuery(unittest.TestCase): | ||
|
||
@unittest.skipIf(os.getenv("PROW_JOB_ID"), "Manual test not run on prow") | ||
def test_embed(self): | ||
"""Test that we can embed the search query string via tf.serving. | ||
This test assumes the model is running as an external process in TensorFlow | ||
serving. | ||
The external process can be started a variety of ways e.g. subprocess, | ||
kubernetes, or docker container. | ||
The script start_test_server.sh can be used to start TFServing in | ||
docker container. | ||
""" | ||
# Directory containing the vocabulary. | ||
test_data_dir = os.path.abspath( | ||
os.path.join(os.path.dirname(__file__), "..", "..", "t2t", "test_data")) | ||
# 8501 should be REST port | ||
server = os.getenv("TEST_SERVER", "localhost:8501") | ||
|
||
# Model name matches the subdirectory in TF Serving's model Directory | ||
# containing models. | ||
model_name = "test_model_20181031" | ||
serving_url = "http://{0}/v1/models/{1}:predict".format(server, model_name) | ||
query = "Write to GCS" | ||
query_encoder = start_search_server.build_query_encoder(PROBLEM_NAME, | ||
test_data_dir) | ||
code_encoder = start_search_server.build_query_encoder(PROBLEM_NAME, | ||
test_data_dir, | ||
embed_code=True) | ||
|
||
query_result = start_search_server.embed_query(query_encoder, serving_url, query) | ||
code_result = start_search_server.embed_query(code_encoder, serving_url, query) | ||
|
||
# As a sanity check ensure the vectors aren't equal | ||
q_vec = np.array(query_result) | ||
q_vec = q_vec/np.sqrt(np.dot(q_vec, q_vec)) | ||
c_vec = np.array(code_result) | ||
c_vec = c_vec/np.sqrt(np.dot(c_vec, c_vec)) | ||
|
||
dist = np.dot(q_vec, c_vec) | ||
self.assertNotAlmostEqual(1, dist) | ||
logging.info("Done") | ||
|
||
if __name__ == "__main__": | ||
logging.getLogger().setLevel(logging.INFO) | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
code_search/src/code_search/nmslib/cli/start_test_server.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/bin/bash | ||
# | ||
# A simple script for starting TFServing locally in a docker container. | ||
# This allows us to test sending predictions to the model. | ||
set -ex | ||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" | ||
MODELS_DIR="$( cd "${DIR}/../../t2t/test_data/" >/dev/null && pwd)" | ||
|
||
MODEL_NAME=test_model_20181031 | ||
|
||
if [ ! -d ${MODELS_DIR}/${MODEL_NAME} ]; then | ||
echo Missing directory ${MODELS_DIR}/${MODEL_NAME} | ||
exit 1 | ||
fi | ||
|
||
set +e | ||
docker rm -f cs_serving_test | ||
set -e | ||
|
||
# TODO(jlewi): Is there anyway to cause TF Serving to load all models in | ||
# MODELS_DIR and not have to specify the environment variable MODEL_NAME | ||
docker run --rm --name=cs_serving_test -p 8500:8500 -p 8501:8501 \ | ||
-v "${MODELS_DIR}:/models" \ | ||
-e MODEL_NAME="${MODEL_NAME}" \ | ||
tensorflow/serving | ||
# Tail the logs | ||
docker logs -f cs_serving_test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.