From 71b8b6f5fd953c4da74157c54a46bb92adb37b7e Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 15 Apr 2024 18:39:03 +0000 Subject: [PATCH] add fixture for redis db index --- server/config.py | 1 + server/nlp/embeddings.py | 6 ++++-- server_tests/conftest.py | 18 +++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/server/config.py b/server/config.py index 14c062c60..1bf1736f4 100644 --- a/server/config.py +++ b/server/config.py @@ -37,6 +37,7 @@ def _get_config_option(name: str, default_value: str | None = None) -> str: "DATABASE_URL", "postgresql://postgres:password@database/pigeondb" ) REDIS_URL = _get_config_option("REDIS_URL", "redis") +REDIS_DB_INDEX = _get_config_option("REDIS_DB_INDEX", "0") FLASK_RUN_PORT = 2010 DEBUG = True diff --git a/server/nlp/embeddings.py b/server/nlp/embeddings.py index 525e5b03d..073857fcf 100644 --- a/server/nlp/embeddings.py +++ b/server/nlp/embeddings.py @@ -17,14 +17,16 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query -from server.config import REDIS_URL, RedisDocument +from server.config import REDIS_DB_INDEX, REDIS_URL, RedisDocument cwd = os.path.dirname(__file__) VECTOR_DIMENSION = 1536 # load redis client -client = redis.Redis(host=REDIS_URL, port=6379, decode_responses=True) +client = redis.Redis( + host=REDIS_URL, port=6379, decode_responses=True, db=int(REDIS_DB_INDEX) +) # load corpus # with open('corpus.json', 'r') as f: diff --git a/server_tests/conftest.py b/server_tests/conftest.py index d6b010920..83997a59d 100644 --- a/server_tests/conftest.py +++ b/server_tests/conftest.py @@ -3,6 +3,7 @@ import psycopg2 import pytest +import redis from apiflask import APIFlask from flask.testing import FlaskClient, FlaskCliRunner from psycopg2 import sql @@ -41,8 +42,23 @@ def db_url(db_name="pigeondb_test"): @pytest.fixture(scope="session") -def app(db_url: str): +def redis_db_index(): + """Yields test db index for Redis. + + Flushes test db if it already exists. + """ + test_db_index = 1 + client = redis.Redis(host="localhost", port=6379, db=test_db_index) + client.flushdb() + client.close() + + yield test_db_index + + +@pytest.fixture(scope="session") +def app(db_url: str, redis_db_index: int): os.environ["DATABASE_URL"] = db_url + os.environ["REDIS_DB_INDEX"] = str(redis_db_index) app = create_app() app.config.update(