Skip to content

Commit

Permalink
Fix sqlalchemy warnings when running tests (#733)
Browse files Browse the repository at this point in the history
This has been bugging me when running my own tests that call langchain
methods :P
  • Loading branch information
amosjyng authored Jan 25, 2023
1 parent bd0bf4e commit fa6826e
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 7 deletions.
3 changes: 1 addition & 2 deletions langchain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, declarative_base

from langchain.schema import Generation

Expand Down
2 changes: 1 addition & 1 deletion langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def run(self, command: str) -> str:
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.connect() as connection:
with self._engine.begin() as connection:
if self._schema is not None:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/llms/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test base LLM functionality."""
from sqlalchemy import Column, Integer, Sequence, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base

import langchain
from langchain.cache import InMemoryCache, SQLAlchemyCache
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_sql_database_run() -> None:
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn:
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "select user_name from user where user_id = 13"
Expand All @@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None:
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn:
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "update user set user_name='Updated' where user_id = 13"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_sql_database_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_sql_database_run() -> None:
engine = create_engine("duckdb:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn:
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine, schema="schema_a")
command = 'select user_name from "user" where user_id = 13'
Expand Down

0 comments on commit fa6826e

Please sign in to comment.