Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more complex sql chain #619

Merged
merged 1 commit into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion docs/modules/chains/examples/sqlite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,82 @@
"db_chain.run(\"How many employees are there in the foobar table?\")"
]
},
{
"cell_type": "markdown",
"id": "c12ae15a",
"metadata": {},
"source": [
"## SQLDatabaseSequentialChain\n",
"\n",
"Chain for querying SQL database that is a sequential chain.\n",
"\n",
"The chain is as follows:\n",
"\n",
" 1. Based on the query, determine which tables to use.\n",
" 2. Based on those tables, call the normal SQL database chain.\n",
"\n",
"This is useful in cases where the number of tables in the database is large."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "e59a4740",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.sql_database.base import SQLDatabaseSequentialChain"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "58bb49b6",
"metadata": {},
"outputs": [],
"source": [
"chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "95017b1a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n",
"Table names to use:\n",
"\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' 0 employees are also customers.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"How many employees are also customers?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2998b03",
"metadata": {},
"outputs": [],
"source": []
}
],
Expand Down
5 changes: 4 additions & 1 deletion langchain/chains/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def validate_chains(cls, values: Dict) -> Dict:
for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables)
if missing_vars:
raise ValueError(f"Missing required input keys: {missing_vars}")
raise ValueError(
f"Missing required input keys: {missing_vars}, "
f"only had {known_variables}"
)
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(
Expand Down
78 changes: 74 additions & 4 deletions langchain/chains/sql_database/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Chain for interacting with SQL Database."""
from typing import Dict, List
from __future__ import annotations

from typing import Any, Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.sql_database import SQLDatabase
Expand Down Expand Up @@ -53,15 +55,18 @@ def output_keys(self) -> List[str]:
"""
return [self.output_key]

def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
if self.verbose:
self.callback_manager.on_text(input_text)
# If not present, then defaults to None which is all tables.
table_names_to_use = inputs.get("table_names_to_use")
table_info = self.database.get_table_info(table_names=table_names_to_use)
llm_inputs = {
"input": input_text,
"dialect": self.database.dialect,
"table_info": self.database.table_info,
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
sql_cmd = llm_chain.predict(**llm_inputs)
Expand All @@ -78,3 +83,68 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
if self.verbose:
self.callback_manager.on_text(final_result, color="green")
return {self.output_key: final_result}


class SQLDatabaseSequentialChain(Chain, BaseModel):
    """Two-step chain for answering questions against a SQL database.

    Steps performed on every call:
        1. Ask the decider chain which tables look relevant to the question.
        2. Run the regular SQL database chain restricted to those tables.

    Handy when the database holds so many tables that describing all of
    them in a single prompt is impractical.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @classmethod
    def from_llm(
        cls,
        llm: BaseLLM,
        database: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Build the chain, wiring both sub-chains to the same LLM.

        ``query_prompt`` drives the SQL chain and ``decider_prompt`` drives
        the table-selection chain; any extra keyword arguments are handed
        to the class constructor unchanged.
        """
        return cls(
            sql_chain=SQLDatabaseChain(
                llm=llm, database=database, prompt=query_prompt
            ),
            decider_chain=LLMChain(
                llm=llm, prompt=decider_prompt, output_key="table_names"
            ),
            **kwargs,
        )

    @property
    def input_keys(self) -> List[str]:
        """Expose the single input key as a one-element list.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expose the single output key as a one-element list.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Step 1: show the decider every available table name and let it
        # pick the ones that matter for this question.
        available = self.sql_chain.database.get_table_names()
        joined_names = ", ".join(available)
        question = inputs[self.input_key]
        table_names_to_use = self.decider_chain.predict_and_parse(
            query=question, table_names=joined_names
        )

        if self.verbose:
            report = self.callback_manager.on_text
            report("Table names to use:", end="\n")
            report(str(table_names_to_use), color="yellow")

        # Step 2: delegate to the SQL chain, passing the chosen tables along
        # under the "table_names_to_use" key.
        sql_inputs = {
            self.sql_chain.input_key: question,
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(sql_inputs, return_only_outputs=True)
14 changes: 14 additions & 0 deletions langchain/chains/sql_database/prompt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# flake8: noqa
from langchain.prompts.base import CommaSeparatedListOutputParser
from langchain.prompts.prompt import PromptTemplate

_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Expand All @@ -17,3 +18,16 @@
PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)

# Prompt text for the table-selection step: the model is shown the user's
# question ({query}) and the candidate table names ({table_names}) and is
# asked to answer with a comma separated list of table names.
# NOTE: this string is sent to the LLM verbatim, so editing its wording
# (including spelling) changes what the model sees.
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be neccessary to answer this question.

Question: {query}

Table Names: {table_names}

Relevant Table Names:"""
# Ready-to-use prompt built from the template above. The attached
# CommaSeparatedListOutputParser is what converts the raw completion text
# into a Python list of table names.
DECIDER_PROMPT = PromptTemplate(
input_variables=["query", "table_names"],
template=_DECIDER_TEMPLATE,
output_parser=CommaSeparatedListOutputParser(),
)
8 changes: 8 additions & 0 deletions langchain/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""


class CommaSeparatedListOutputParser(ListOutputParser):
    """Parse a comma separated list out of raw LLM output."""

    def parse(self, text: str) -> List[str]:
        """Split the output of an LLM call into its comma separated items.

        Args:
            text: Raw completion text, e.g. ``" Employee, Customer\\n"``.

        Returns:
            The items in their original order, each stripped of surrounding
            whitespace. Empty items (from blank output, a trailing comma or
            doubled commas) are dropped.
        """
        # Split on the bare comma instead of the exact ", " sequence: models
        # do not reliably emit exactly one space after each comma, and
        # "a,b" or "a,  b" must parse the same way as "a, b".
        stripped_items = (item.strip() for item in text.split(","))
        return [item for item in stripped_items if item]


class RegexParser(BaseOutputParser, BaseModel):
"""Class to parse the output into a dictionary."""

Expand Down
15 changes: 13 additions & 2 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,28 @@ def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self._engine.dialect.name

def _get_table_names(self) -> Iterable[str]:
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return set(self._all_tables) - set(self._ignore_tables)

@property
def table_info(self) -> str:
"""Information about all tables in the database."""
return self.get_table_info()

def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
all_table_names = self.get_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
template = "Table '{table_name}' has columns: {columns}."
tables = []
for table_name in self._get_table_names():
for table_name in all_table_names:
columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")
Expand Down