From 86fa3284183c071b3d7119729752eaffc4116033 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 14 Jan 2023 20:22:29 -0800 Subject: [PATCH] more complex sql chain --- docs/modules/chains/examples/sqlite.ipynb | 72 ++++++++++++++++++++- langchain/chains/sequential.py | 5 +- langchain/chains/sql_database/base.py | 78 +++++++++++++++++++++-- langchain/chains/sql_database/prompt.py | 14 ++++ langchain/prompts/base.py | 8 +++ langchain/sql_database.py | 15 ++++- 6 files changed, 184 insertions(+), 8 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index 66a4ba8ea19a2..3350fa68f0321 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -179,12 +179,82 @@ "db_chain.run(\"How many employees are there in the foobar table?\")" ] }, + { + "cell_type": "markdown", + "id": "c12ae15a", + "metadata": {}, + "source": [ + "## SQLDatabaseSequentialChain\n", + "\n", + "Chain for querying SQL database that is a sequential chain.\n", + "\n", + "The chain is as follows:\n", + "\n", + " 1. Based on the query, determine which tables to use.\n", + " 2. Based on those tables, call the normal SQL database chain.\n", + "\n", + "This is useful in cases where the number of tables in the database is large." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "e59a4740", "metadata": {}, "outputs": [], + "source": [ + "from langchain.chains.sql_database.base import SQLDatabaseSequentialChain" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "58bb49b6", + "metadata": {}, + "outputs": [], + "source": [ + "chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "95017b1a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n", + "Table names to use:\n", + "\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "' 0 employees are also customers.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"How many employees are also customers?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2998b03", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 9db4be411e3f9..a3ca88989a6a6 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -47,7 +47,10 @@ def validate_chains(cls, values: Dict) -> Dict: for chain in chains: missing_vars = set(chain.input_keys).difference(known_variables) if missing_vars: - raise ValueError(f"Missing required input keys: {missing_vars}") + raise ValueError( + f"Missing required input keys: {missing_vars}, " + f"only had {known_variables}" + ) overlapping_keys = known_variables.intersection(chain.output_keys) if overlapping_keys: raise ValueError( diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 9fbc0ab2f1caf..190d5ab5b581d 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -1,11 +1,13 @@ """Chain for interacting with SQL Database.""" -from typing import Dict, List +from __future__ import annotations + +from typing import Any, Dict, List from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.sql_database.prompt import PROMPT +from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.sql_database import SQLDatabase @@ -53,15 +55,18 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) input_text = f"{inputs[self.input_key]} \nSQLQuery:" if self.verbose: self.callback_manager.on_text(input_text) + # If not present, then defaults to None which is all tables. + table_names_to_use = inputs.get("table_names_to_use") + table_info = self.database.get_table_info(table_names=table_names_to_use) llm_inputs = { "input": input_text, "dialect": self.database.dialect, - "table_info": self.database.table_info, + "table_info": table_info, "stop": ["\nSQLResult:"], } sql_cmd = llm_chain.predict(**llm_inputs) @@ -78,3 +83,68 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: if self.verbose: self.callback_manager.on_text(final_result, color="green") return {self.output_key: final_result} + + +class SQLDatabaseSequentialChain(Chain, BaseModel): + """Chain for querying SQL database that is a sequential chain. + + The chain is as follows: + 1. Based on the query, determine which tables to use. + 2. Based on those tables, call the normal SQL database chain. + + This is useful in cases where the number of tables in the database is large. + """ + + @classmethod + def from_llm( + cls, + llm: BaseLLM, + database: SQLDatabase, + query_prompt: BasePromptTemplate = PROMPT, + decider_prompt: BasePromptTemplate = DECIDER_PROMPT, + **kwargs: Any, + ) -> SQLDatabaseSequentialChain: + """Load the necessary chains.""" + sql_chain = SQLDatabaseChain(llm=llm, database=database, prompt=query_prompt) + decider_chain = LLMChain( + llm=llm, prompt=decider_prompt, output_key="table_names" + ) + return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs) + + decider_chain: LLMChain + sql_chain: SQLDatabaseChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the singular input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key. + + :meta private: + """ + return [self.output_key] + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + _table_names = self.sql_chain.database.get_table_names() + table_names = ", ".join(_table_names) + llm_inputs = { + "query": inputs[self.input_key], + "table_names": table_names, + } + table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs) + if self.verbose: + self.callback_manager.on_text("Table names to use:", end="\n") + self.callback_manager.on_text(str(table_names_to_use), color="yellow") + new_inputs = { + self.sql_chain.input_key: inputs[self.input_key], + "table_names_to_use": table_names_to_use, + } + return self.sql_chain(new_inputs, return_only_outputs=True) diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index 1bc9ffd9fc59d..2ae3973c5bb13 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -1,4 +1,5 @@ # flake8: noqa +from langchain.prompts.base import CommaSeparatedListOutputParser from langchain.prompts.prompt import PromptTemplate _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. @@ -17,3 +18,16 @@ PROMPT = PromptTemplate( input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE ) + +_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be neccessary to answer this question. + +Question: {query} + +Table Names: {table_names} + +Relevant Table Names:""" +DECIDER_PROMPT = PromptTemplate( + input_variables=["query", "table_names"], + template=_DECIDER_TEMPLATE, + output_parser=CommaSeparatedListOutputParser(), +) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 60cc780482610..21e5a6355fea8 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -64,6 +64,14 @@ def parse(self, text: str) -> List[str]: """Parse the output of an LLM call.""" +class CommaSeparatedListOutputParser(ListOutputParser): + """Parse out comma separated lists.""" + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + return text.strip().split(", ") + + class RegexParser(BaseOutputParser, BaseModel): """Class to parse the output into a dictionary.""" diff --git a/langchain/sql_database.py b/langchain/sql_database.py index f8fc00974b3c0..d4695c0b4f810 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -50,7 +50,8 @@ def dialect(self) -> str: """Return string representation of dialect to use.""" return self._engine.dialect.name - def _get_table_names(self) -> Iterable[str]: + def get_table_names(self) -> Iterable[str]: + """Get names of tables available.""" if self._include_tables: return self._include_tables return set(self._all_tables) - set(self._ignore_tables) @@ -58,9 +59,19 @@ def _get_table_names(self) -> Iterable[str]: @property def table_info(self) -> str: """Information about all tables in the database.""" + return self.get_table_info() + + def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables.""" + all_table_names = self.get_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + all_table_names = table_names template = "Table '{table_name}' has columns: {columns}." tables = [] - for table_name in self._get_table_names(): + for table_name in all_table_names: columns = [] for column in self._inspector.get_columns(table_name, schema=self._schema): columns.append(f"{column['name']} ({str(column['type'])})")