diff --git a/projects/vdk-plugins/vdk-smarter/requirements.txt b/projects/vdk-plugins/vdk-smarter/requirements.txt index 4a1b3ecaa5..ae232aa357 100644 --- a/projects/vdk-plugins/vdk-smarter/requirements.txt +++ b/projects/vdk-plugins/vdk-smarter/requirements.txt @@ -4,5 +4,7 @@ openai pytest +pytest-httpserver vdk-core +vdk-sqlite vdk-test-utils diff --git a/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/10_create_table.sql b/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/10_create_table.sql new file mode 100644 index 0000000000..a1903a7703 --- /dev/null +++ b/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/10_create_table.sql @@ -0,0 +1,2 @@ + +CREATE TABLE stocks (date text, symbol text, price real) diff --git a/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/20_populate_table.sql b/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/20_populate_table.sql new file mode 100644 index 0000000000..bbe7257c58 --- /dev/null +++ b/projects/vdk-plugins/vdk-smarter/tests/jobs/sql-job/20_populate_table.sql @@ -0,0 +1 @@ +INSERT INTO stocks VALUES ('2020-01-01', 'GOOG', 123.0), ('2020-01-01', 'GOOG', 123.0) diff --git a/projects/vdk-plugins/vdk-smarter/tests/test_sql_review.py b/projects/vdk-plugins/vdk-smarter/tests/test_sql_review.py new file mode 100644 index 0000000000..3dc46df44e --- /dev/null +++ b/projects/vdk-plugins/vdk-smarter/tests/test_sql_review.py @@ -0,0 +1,60 @@ +# Copyright 2021-2023 VMware, Inc. +# SPDX-License-Identifier: Apache-2.0 +import os +import pathlib +import re +from unittest import mock + +import openai +from click.testing import Result +from pytest_httpserver import HTTPServer +from vdk.plugin.smarter import openai_plugin_entry +from vdk.plugin.sqlite import sqlite_plugin +from vdk.plugin.test_utils.util_funcs import cli_assert +from vdk.plugin.test_utils.util_funcs import cli_assert_equal +from vdk.plugin.test_utils.util_funcs import CliEntryBasedTestRunner +from vdk.plugin.test_utils.util_funcs import jobs_path_from_caller_directory + +# uses the pytest tmpdir fixture - https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture + + +def test_openai_review_plugin(tmpdir, httpserver: HTTPServer): + # Mock OpenAI response + review_comment = "This is a well written SQL query." + httpserver.expect_oneshot_request(re.compile(r".*")).respond_with_json( + { + "id": "test", + "model": "foo", + "choices": [ + { + "text": " Here is the review: {'score': 4, 'review': '" + + review_comment + + "'}", + } + ], + } + ) + + # Set the OpenAI endpoint to the httpserver's uri + openai.api_base = httpserver.url_for("") + + with mock.patch.dict( + os.environ, + { + "VDK_DB_DEFAULT_TYPE": "SQLITE", + "VDK_SQLITE_FILE": str(tmpdir) + "vdk-sqlite.db", + "VDK_OPENAI_REVIEW_ENABLED": "true", + "VDK_OPENAI_MODEL": "foo", + }, + ): + runner = CliEntryBasedTestRunner(openai_plugin_entry, sqlite_plugin) + + result: Result = runner.invoke( + ["run", jobs_path_from_caller_directory("sql-job")] + ) + + cli_assert_equal(0, result) + cli_assert(review_comment in result.output, result) + + assert pathlib.Path("queries_reviews_report.md").exists() + assert review_comment in pathlib.Path("queries_reviews_report.md").read_text()