diff --git a/templates/Makefile b/templates/Makefile
index 7460b2ee51623..445235c22c841 100644
--- a/templates/Makefile
+++ b/templates/Makefile
@@ -1,2 +1,8 @@
lint lint_diff:
- poetry run ruff .
+ poetry run poe lint
+
+test:
+ poetry run poe test
+
+format:
+ poetry run poe format
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/__init__.py b/templates/anthropic-iterative-search/anthropic_iterative_search/__init__.py
index 81cc8c187fa4f..ef5463d1979c1 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/__init__.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/__init__.py
@@ -1,7 +1,7 @@
from langchain.schema.runnable import ConfigurableField
+from .chain import chain
from .retriever_agent import executor
-from .chain import chain
final_chain = chain.configurable_alternatives(
ConfigurableField(id="chain"),
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/chain.py b/templates/anthropic-iterative-search/anthropic_iterative_search/chain.py
index 6ccbbda9343fe..ba2a7d2bcd22f 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/chain.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/chain.py
@@ -1,5 +1,5 @@
-from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatAnthropic
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from .prompts import answer_prompt
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/output_parser.py b/templates/anthropic-iterative-search/anthropic_iterative_search/output_parser.py
index f04837c928800..f1e981d761060 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/output_parser.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/output_parser.py
@@ -1,6 +1,7 @@
-from langchain.schema.agent import AgentAction, AgentFinish
import re
+from langchain.schema.agent import AgentAction, AgentFinish
+
from .agent_scratchpad import _format_docs
@@ -14,18 +15,23 @@ def extract_between_tags(tag: str, string: str, strip: bool = True) -> str:
# Only return the first one
return ext_list[0]
+
def parse_output(outputs):
partial_completion = outputs["partial_completion"]
steps = outputs["intermediate_steps"]
- search_query = extract_between_tags('search_query', partial_completion + '')
+ search_query = extract_between_tags(
+ "search_query", partial_completion + ""
+ )
if search_query is None:
docs = []
str_output = ""
for action, observation in steps:
docs.extend(observation)
str_output += action.log
- str_output += '' + _format_docs(observation)
+ str_output += "" + _format_docs(observation)
str_output += partial_completion
return AgentFinish({"docs": docs, "output": str_output}, log=partial_completion)
else:
- return AgentAction(tool="search", tool_input=search_query, log=partial_completion)
+ return AgentAction(
+ tool="search", tool_input=search_query, log=partial_completion
+ )
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/prompts.py b/templates/anthropic-iterative-search/anthropic_iterative_search/prompts.py
index e58cd4bf4d79a..fe46574cb4483 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/prompts.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/prompts.py
@@ -2,6 +2,6 @@
After each call to the Search Engine Tool, reflect briefly inside tags about whether you now have enough information to answer, or whether more information is needed. If you have all the relevant information, write it in tags, WITHOUT actually answering the question. Otherwise, issue a new search.
-Here is the user's question: {query} Remind yourself to make short queries in your scratchpad as you plan out your strategy."""
+Here is the user's question: {query} Remind yourself to make short queries in your scratchpad as you plan out your strategy.""" # noqa: E501
-answer_prompt = "Here is a user query: {query}. Here is some relevant information: {information}. Please answer the question using the relevant information."
+answer_prompt = "Here is a user query: {query}. Here is some relevant information: {information}. Please answer the question using the relevant information." # noqa: E501
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/retriever.py b/templates/anthropic-iterative-search/anthropic_iterative_search/retriever.py
index d565dbbd9e295..5377e65be32cb 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/retriever.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/retriever.py
@@ -3,13 +3,14 @@
# This is used to tell the model how to best use the retriever.
-retriever_description = """You will be asked a question by a human user. You have access to the following tool to help answer the question. Search Engine Tool * The search engine will exclusively search over Wikipedia for pages similar to your query. It returns for each page its title and full page content. Use this tool if you want to get up-to-date and comprehensive information on a topic to help answer queries. Queries should be as atomic as possible -- they only need to address one part of the user's question. For example, if the user's query is "what is the color of a basketball?", your search query should be "basketball". Here's another example: if the user's question is "Who created the first neural network?", your first query should be "neural network". As you can see, these queries are quite short. Think keywords, not phrases. * At any time, you can make a call to the search engine using the following syntax: query_word. * You'll then get results back in tags."""
+retriever_description = """You will be asked a question by a human user. You have access to the following tool to help answer the question. Search Engine Tool * The search engine will exclusively search over Wikipedia for pages similar to your query. It returns for each page its title and full page content. Use this tool if you want to get up-to-date and comprehensive information on a topic to help answer queries. Queries should be as atomic as possible -- they only need to address one part of the user's question. For example, if the user's query is "what is the color of a basketball?", your search query should be "basketball". Here's another example: if the user's question is "Who created the first neural network?", your first query should be "neural network". As you can see, these queries are quite short. Think keywords, not phrases. * At any time, you can make a call to the search engine using the following syntax: query_word. * You'll then get results back in tags.""" # noqa: E501
retriever = WikipediaRetriever()
# This should be the same as the function name below
RETRIEVER_TOOL_NAME = "search"
+
@tool
def search(query):
"""Search with the retriever."""
diff --git a/templates/anthropic-iterative-search/anthropic_iterative_search/retriever_agent.py b/templates/anthropic-iterative-search/anthropic_iterative_search/retriever_agent.py
index 139c631aabe4d..1ee01f34c76a5 100644
--- a/templates/anthropic-iterative-search/anthropic_iterative_search/retriever_agent.py
+++ b/templates/anthropic-iterative-search/anthropic_iterative_search/retriever_agent.py
@@ -1,13 +1,13 @@
+from langchain.agents import AgentExecutor
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
-from langchain.schema.runnable import RunnablePassthrough, RunnableMap
from langchain.schema.output_parser import StrOutputParser
-from langchain.agents import AgentExecutor
+from langchain.schema.runnable import RunnableMap, RunnablePassthrough
-from .retriever import search, RETRIEVER_TOOL_NAME, retriever_description
-from .prompts import retrieval_prompt
from .agent_scratchpad import format_agent_scratchpad
from .output_parser import parse_output
+from .prompts import retrieval_prompt
+from .retriever import retriever_description, search
prompt = ChatPromptTemplate.from_messages([
("user", retrieval_prompt),
diff --git a/templates/anthropic-iterative-search/main.py b/templates/anthropic-iterative-search/main.py
index 5bb812a465cee..27b7aa1aa6afa 100644
--- a/templates/anthropic-iterative-search/main.py
+++ b/templates/anthropic-iterative-search/main.py
@@ -1,6 +1,12 @@
-from anthropic_iterative_search import final_chain
-
+from anthropic_iterative_search import final_chain
if __name__ == "__main__":
- query = "Which movie came out first: Oppenheimer, or Are You There God It's Me Margaret?"
- print(final_chain.with_config(configurable={"chain": "retrieve"}).invoke({"query": query}))
+ query = (
+ "Which movie came out first: Oppenheimer, or "
+ "Are You There God It's Me Margaret?"
+ )
+ print(
+ final_chain.with_config(configurable={"chain": "retrieve"}).invoke(
+ {"query": query}
+ )
+ )
diff --git a/templates/cassandra-entomology-rag/cassandra_entomology_rag/__init__.py b/templates/cassandra-entomology-rag/cassandra_entomology_rag/__init__.py
index 590c9465a8e69..6af8acd87df0f 100644
--- a/templates/cassandra-entomology-rag/cassandra_entomology_rag/__init__.py
+++ b/templates/cassandra-entomology-rag/cassandra_entomology_rag/__init__.py
@@ -1,14 +1,12 @@
import os
import cassio
-
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import Cassandra
from langchain.prompts import ChatPromptTemplate
-from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
-
+from langchain.schema.runnable import RunnablePassthrough
+from langchain.vectorstores import Cassandra
use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:
diff --git a/templates/cassandra-entomology-rag/cassandra_entomology_rag/cassandra_cluster_init.py b/templates/cassandra-entomology-rag/cassandra_entomology_rag/cassandra_cluster_init.py
index e59e5bf6dfc61..9c4ce5b9b662f 100644
--- a/templates/cassandra-entomology-rag/cassandra_entomology_rag/cassandra_cluster_init.py
+++ b/templates/cassandra-entomology-rag/cassandra_entomology_rag/cassandra_cluster_init.py
@@ -1,13 +1,13 @@
import os
-from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
+from cassandra.cluster import Cluster
def get_cassandra_connection():
contact_points = [
cp.strip()
- for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(',')
+ for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
if cp.strip()
]
CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]
@@ -22,6 +22,8 @@ def get_cassandra_connection():
else:
auth_provider = None
- c_cluster = Cluster(contact_points if contact_points else None, auth_provider=auth_provider)
+ c_cluster = Cluster(
+ contact_points if contact_points else None, auth_provider=auth_provider
+ )
session = c_cluster.connect()
return (session, CASSANDRA_KEYSPACE)
diff --git a/templates/cassandra-entomology-rag/setup.py b/templates/cassandra-entomology-rag/setup.py
index a84ce59a2f2db..5747d67216222 100644
--- a/templates/cassandra-entomology-rag/setup.py
+++ b/templates/cassandra-entomology-rag/setup.py
@@ -1,14 +1,13 @@
import os
import cassio
-
-from langchain.vectorstores import Cassandra
from langchain.embeddings import OpenAIEmbeddings
-
+from langchain.vectorstores import Cassandra
use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:
from cassandra_entomology_rag.cassandra_cluster_init import get_cassandra_connection
+
session, keyspace = get_cassandra_connection()
cassio.init(
session=session,
@@ -22,7 +21,7 @@
)
-if __name__ == '__main__':
+if __name__ == "__main__":
embeddings = OpenAIEmbeddings()
vector_store = Cassandra(
session=None,
@@ -32,16 +31,13 @@
)
#
lines = [
- l.strip()
- for l in open("sources.txt").readlines()
- if l.strip()
- if l[0] != "#"
+ line.strip()
+ for line in open("sources.txt").readlines()
+ if line.strip()
+ if line[0] != "#"
]
# deterministic IDs to prevent duplicates on multiple runs
- ids = [
- "_".join(l.split(" ")[:2]).lower().replace(":", "")
- for l in lines
- ]
+ ids = ["_".join(line.split(" ")[:2]).lower().replace(":", "") for line in lines]
#
vector_store.add_texts(texts=lines, ids=ids)
print(f"Done ({len(lines)} lines inserted).")
diff --git a/templates/cassandra-synonym-caching/cassandra_synonym_caching/__init__.py b/templates/cassandra-synonym-caching/cassandra_synonym_caching/__init__.py
index 412d5a4e427e7..575ea10879742 100644
--- a/templates/cassandra-synonym-caching/cassandra_synonym_caching/__init__.py
+++ b/templates/cassandra-synonym-caching/cassandra_synonym_caching/__init__.py
@@ -1,13 +1,12 @@
import os
import cassio
-
import langchain
-from langchain.schema import BaseMessage
-from langchain.prompts import ChatPromptTemplate
+from langchain.cache import CassandraCache
from langchain.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema import BaseMessage
from langchain.schema.runnable import RunnableLambda
-from langchain.cache import CassandraCache
use_cassandra = int(os.environ.get("USE_CASSANDRA_CLUSTER", "0"))
if use_cassandra:
diff --git a/templates/cassandra-synonym-caching/cassandra_synonym_caching/cassandra_cluster_init.py b/templates/cassandra-synonym-caching/cassandra_synonym_caching/cassandra_cluster_init.py
index e59e5bf6dfc61..9c4ce5b9b662f 100644
--- a/templates/cassandra-synonym-caching/cassandra_synonym_caching/cassandra_cluster_init.py
+++ b/templates/cassandra-synonym-caching/cassandra_synonym_caching/cassandra_cluster_init.py
@@ -1,13 +1,13 @@
import os
-from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
+from cassandra.cluster import Cluster
def get_cassandra_connection():
contact_points = [
cp.strip()
- for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(',')
+ for cp in os.environ.get("CASSANDRA_CONTACT_POINTS", "").split(",")
if cp.strip()
]
CASSANDRA_KEYSPACE = os.environ["CASSANDRA_KEYSPACE"]
@@ -22,6 +22,8 @@ def get_cassandra_connection():
else:
auth_provider = None
- c_cluster = Cluster(contact_points if contact_points else None, auth_provider=auth_provider)
+ c_cluster = Cluster(
+ contact_points if contact_points else None, auth_provider=auth_provider
+ )
session = c_cluster.connect()
return (session, CASSANDRA_KEYSPACE)
diff --git a/templates/csv-agent/csv_agent/agent.py b/templates/csv-agent/csv_agent/agent.py
index 9badaaeee0a68..9d04f36694fa3 100644
--- a/templates/csv-agent/csv_agent/agent.py
+++ b/templates/csv-agent/csv_agent/agent.py
@@ -1,24 +1,25 @@
-from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_experimental.tools import PythonAstREPLTool
+from pathlib import Path
+
import pandas as pd
+from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.chat_models import ChatOpenAI
-from langsmith import Client
-from langchain.smith import RunEvalConfig, run_on_dataset
-from pydantic import BaseModel, Field
from langchain.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import FAISS
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.retriever import create_retriever_tool
-from pathlib import Path
+from langchain.vectorstores import FAISS
+from langchain_experimental.tools import PythonAstREPLTool
+from pydantic import BaseModel, Field
MAIN_DIR = Path(__file__).parents[1]
-pd.set_option('display.max_rows', 20)
-pd.set_option('display.max_columns', 20)
+pd.set_option("display.max_rows", 20)
+pd.set_option("display.max_columns", 20)
embedding_model = OpenAIEmbeddings()
vectorstore = FAISS.load_local(MAIN_DIR / "titanic_data", embedding_model)
-retriever_tool = create_retriever_tool(vectorstore.as_retriever(), "person_name_search", "Search for a person by name")
+retriever_tool = create_retriever_tool(
+ vectorstore.as_retriever(), "person_name_search", "Search for a person by name"
+)
TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
@@ -41,8 +42,7 @@
Who has id 320
Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.
-"""
-
+""" # noqa: E501
class PythonInputs(BaseModel):
@@ -52,15 +52,24 @@ class PythonInputs(BaseModel):
df = pd.read_csv("titanic.csv")
template = TEMPLATE.format(dhead=df.head().to_markdown())
-prompt = ChatPromptTemplate.from_messages([
- ("system", template),
- MessagesPlaceholder(variable_name="agent_scratchpad"),
- ("human", "{input}")
-])
-
-repl = PythonAstREPLTool(locals={"df": df}, name="python_repl",
- description="Runs code and returns the output of the final line",
- args_schema=PythonInputs)
+prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", template),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ("human", "{input}"),
+ ]
+)
+
+repl = PythonAstREPLTool(
+ locals={"df": df},
+ name="python_repl",
+ description="Runs code and returns the output of the final line",
+ args_schema=PythonInputs,
+)
tools = [repl, retriever_tool]
-agent = OpenAIFunctionsAgent(llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools)
-agent_executor = AgentExecutor(agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate")
+agent = OpenAIFunctionsAgent(
+ llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=prompt, tools=tools
+)
+agent_executor = AgentExecutor(
+ agent=agent, tools=tools, max_iterations=5, early_stopping_method="generate"
+)
diff --git a/templates/csv-agent/ingest.py b/templates/csv-agent/ingest.py
index 5bb6784ea8a77..7da5942f172c0 100644
--- a/templates/csv-agent/ingest.py
+++ b/templates/csv-agent/ingest.py
@@ -1,5 +1,4 @@
from langchain.document_loaders import CSVLoader
-from langchain.tools.retriever import create_retriever_tool
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import FAISS
diff --git a/templates/elastic-query-generator/elastic_query_generator/chain.py b/templates/elastic-query-generator/elastic_query_generator/chain.py
index 917dbd71da3bf..aa4575ed02753 100644
--- a/templates/elastic-query-generator/elastic_query_generator/chain.py
+++ b/templates/elastic-query-generator/elastic_query_generator/chain.py
@@ -1,11 +1,12 @@
import os
+from pathlib import Path
+
+from elasticsearch import Elasticsearch
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser
-from elasticsearch import Elasticsearch
-from pathlib import Path
-from .prompts import DSL_PROMPT
from .elastic_index_info import get_indices_infos
+from .prompts import DSL_PROMPT
es_host = os.environ["ELASTIC_SEARCH_SERVER"]
es_password = os.environ["ELASTIC_PASSWORD"]
diff --git a/templates/elastic-query-generator/elastic_query_generator/elastic_index_info.py b/templates/elastic-query-generator/elastic_query_generator/elastic_index_info.py
index 8d059afdd8f54..db328389eac75 100644
--- a/templates/elastic-query-generator/elastic_query_generator/elastic_index_info.py
+++ b/templates/elastic-query-generator/elastic_query_generator/elastic_index_info.py
@@ -1,5 +1,6 @@
from typing import List
+
def _list_indices(database, include_indices=None, ignore_indices=None) -> List[str]:
all_indices = [
index["index"] for index in database.cat.indices(format="json")
diff --git a/templates/elastic-query-generator/elastic_query_generator/prompts.py b/templates/elastic-query-generator/elastic_query_generator/prompts.py
index 1df5c1e78b3e5..861a5874593bf 100644
--- a/templates/elastic-query-generator/elastic_query_generator/prompts.py
+++ b/templates/elastic-query-generator/elastic_query_generator/prompts.py
@@ -16,6 +16,6 @@
Question: Question here
ESQuery: Elasticsearch Query formatted as json
-"""
+""" # noqa: E501
DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX)
diff --git a/templates/elastic-query-generator/ingest.py b/templates/elastic-query-generator/ingest.py
index 528e5324909a0..757a2c6a1bafb 100644
--- a/templates/elastic-query-generator/ingest.py
+++ b/templates/elastic-query-generator/ingest.py
@@ -1,4 +1,5 @@
import os
+
from elasticsearch import Elasticsearch
es_host = os.environ["ELASTIC_SEARCH_SERVER"]
diff --git a/templates/elastic-query-generator/main.py b/templates/elastic-query-generator/main.py
index 30025fe88414a..4f848b6e88ac6 100644
--- a/templates/elastic-query-generator/main.py
+++ b/templates/elastic-query-generator/main.py
@@ -1,5 +1,4 @@
from elastic_query_generator.chain import chain
-
if __name__ == "__main__":
print(chain.invoke({"input": "how many customers named Carol"}))
diff --git a/templates/extraction-openai-functions/extraction_openai_functions/chain.py b/templates/extraction-openai-functions/extraction_openai_functions/chain.py
index 6e9dcc530cef8..ccf3c876848c5 100644
--- a/templates/extraction-openai-functions/extraction_openai_functions/chain.py
+++ b/templates/extraction-openai-functions/extraction_openai_functions/chain.py
@@ -1,40 +1,46 @@
-from langchain.pydantic_v1 import BaseModel
+import json
from typing import List, Optional
+
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
-from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
-import json
-
template = """A article will be passed to you. Extract from it all papers that are mentioned by this article.
Do not extract the name of the article itself. If no papers are mentioned that's fine - you don't need to extract any! Just return an empty list.
-Do not make up or guess ANY extra information. Only extract what exactly is in the text."""
+Do not make up or guess ANY extra information. Only extract what exactly is in the text.""" # noqa: E501
+
+prompt = ChatPromptTemplate.from_messages([("system", template), ("human", "{input}")])
-prompt = ChatPromptTemplate.from_messages([
- ("system", template),
- ("human", "{input}")
-])
# Function output schema
class Paper(BaseModel):
"""Information about papers mentioned."""
+
title: str
author: Optional[str]
class Info(BaseModel):
"""Information to extract"""
+
papers: List[Paper]
+
# Function definition
model = ChatOpenAI()
function = [convert_pydantic_to_openai_function(Info)]
-chain = prompt | model.bind(
- functions=function, function_call={"name": "Info"}
-) | (lambda x: json.loads(x.additional_kwargs['function_call']['arguments'])['papers'])
+chain = (
+ prompt
+ | model.bind(functions=function, function_call={"name": "Info"})
+ | (
+ lambda x: json.loads(x.additional_kwargs["function_call"]["arguments"])[
+ "papers"
+ ]
+ )
+)
# chain = prompt | model.bind(
# functions=function, function_call={"name": "Info"}
diff --git a/templates/hyde/hyde/chain.py b/templates/hyde/hyde/chain.py
index 99021257b08db..915c994d6c3fe 100644
--- a/templates/hyde/hyde/chain.py
+++ b/templates/hyde/hyde/chain.py
@@ -1,14 +1,15 @@
-from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma
+
from hyde.prompts import hyde_prompt
# Example for document loading (from url), splitting, and creating vectostore
-'''
+"""
# Load
from langchain.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
@@ -25,13 +26,13 @@
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
-'''
+"""
# Embed a single document as a test
vectorstore = Chroma.from_texts(
["harrison worked at kensho"],
collection_name="rag-chroma",
- embedding=OpenAIEmbeddings()
+ embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
@@ -48,11 +49,18 @@
# RAG chain
chain = (
- RunnableParallel({
- # Configure the input, pass it the prompt, pass that to the model, and then the result to the retriever
- "context": {"input": RunnablePassthrough()} | hyde_prompt | model | StrOutputParser() | retriever,
- "question": RunnablePassthrough()
- })
+ RunnableParallel(
+ {
+ # Configure the input, pass it the prompt, pass that to the model,
+ # and then the result to the retriever
+ "context": {"input": RunnablePassthrough()}
+ | hyde_prompt
+ | model
+ | StrOutputParser()
+ | retriever,
+ "question": RunnablePassthrough(),
+ }
+ )
| prompt
| model
| StrOutputParser()
diff --git a/templates/hyde/hyde/prompts.py b/templates/hyde/hyde/prompts.py
index 3ffd47816331c..b6d2f0881b05b 100644
--- a/templates/hyde/hyde/prompts.py
+++ b/templates/hyde/hyde/prompts.py
@@ -7,7 +7,7 @@
Passage:"""
sci_fact_template = """Please write a scientific paper passage to support/refute the claim
Claim: {input}
-Passage:"""
+Passage:""" # noqa: E501
fiqa_template = """Please write a financial article passage to answer the question
Question: {input}
Passage:"""
diff --git a/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py b/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py
index acedfe4aac0e7..4868140b91622 100644
--- a/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py
+++ b/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py
@@ -1,13 +1,13 @@
-
from typing import List, Optional
+from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
+from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
-from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
-from langchain.chains.openai_functions import create_structured_output_chain
+
try:
from pydantic.v1.main import BaseModel, Field
except ImportError:
@@ -27,15 +27,18 @@
cypher_llm = ChatOpenAI(model_name="gpt-4", temperature=0.0)
qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)
+
# Extract entities from text
class Entities(BaseModel):
"""Identifying information about entities."""
names: List[str] = Field(
...,
- description="All the person, organization, or business entities that appear in the text",
+ description="All the person, organization, or business entities that "
+ "appear in the text",
)
+
prompt = ChatPromptTemplate.from_messages(
[
(
@@ -44,11 +47,13 @@ class Entities(BaseModel):
),
(
"human",
- "Use the given format to extract information from the following input: {question}",
+ "Use the given format to extract information from the following "
+ "input: {question}",
),
]
)
+
# Fulltext index query
def map_to_database(entities: Entities) -> Optional[str]:
result = ""
@@ -56,16 +61,16 @@ def map_to_database(entities: Entities) -> Optional[str]:
response = graph.query(
"CALL db.index.fulltext.queryNodes('entity', $entity + '*', {limit:1})"
" YIELD node,score RETURN node.name AS result",
- {"entity":entity})
+ {"entity": entity},
+ )
try:
result += f"{entity} maps to {response[0]['result']} in database\n"
except IndexError:
pass
return result
-entity_chain = create_structured_output_chain(
- Entities, qa_llm, prompt
-)
+
+entity_chain = create_structured_output_chain(Entities, qa_llm, prompt)
# Generate Cypher statement based on natural language input
cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
@@ -73,7 +78,7 @@ def map_to_database(entities: Entities) -> Optional[str]:
Entities in the question map to the following database values:
{entities_list}
Question: {question}
-Cypher query:"""
+Cypher query:""" # noqa: E501
cypher_prompt = ChatPromptTemplate.from_messages(
[
@@ -88,7 +93,7 @@ def map_to_database(entities: Entities) -> Optional[str]:
cypher_response = (
RunnablePassthrough.assign(names=entity_chain)
| RunnablePassthrough.assign(
- entities_list=lambda x: map_to_database(x['names']['function']),
+ entities_list=lambda x: map_to_database(x["names"]["function"]),
schema=lambda _: graph.get_schema,
)
| cypher_prompt
@@ -100,13 +105,14 @@ def map_to_database(entities: Entities) -> Optional[str]:
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
-Cypher Response: {response}"""
+Cypher Response: {response}""" # noqa: E501
response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
- "Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
+ "Given an input question and Cypher response, convert it to a natural"
+ " language answer. No pre-amble.",
),
("human", response_template),
]
diff --git a/templates/neo4j-cypher/neo4j_cypher/chain.py b/templates/neo4j-cypher/neo4j_cypher/chain.py
index befbb9ebb3053..730b2d2947dd8 100644
--- a/templates/neo4j-cypher/neo4j_cypher/chain.py
+++ b/templates/neo4j-cypher/neo4j_cypher/chain.py
@@ -1,9 +1,9 @@
+from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
-from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
# Connection to Neo4j
graph = Neo4jGraph()
@@ -24,7 +24,7 @@
{schema}
Question: {question}
-Cypher query:"""
+Cypher query:""" # noqa: E501
cypher_prompt = ChatPromptTemplate.from_messages(
[
@@ -49,13 +49,14 @@
response_template = """Based on the the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
-Cypher Response: {response}"""
+Cypher Response: {response}""" # noqa: E501
response_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
- "Given an input question and Cypher response, convert it to a natural language answer. No pre-amble.",
+ "Given an input question and Cypher response, convert it to a "
+ "natural language answer. No pre-amble.",
),
("human", response_template),
]
diff --git a/templates/neo4j-generation/main.py b/templates/neo4j-generation/main.py
index df2298613e002..578a18013fe6b 100644
--- a/templates/neo4j-generation/main.py
+++ b/templates/neo4j-generation/main.py
@@ -1,6 +1,5 @@
from neo4j_generation.chain import chain
-
if __name__ == "__main__":
text = "Harrison works at LangChain, which is located in San Francisco"
allowed_nodes = ["Person", "Organization", "Location"]
diff --git a/templates/neo4j-generation/neo4j_generation/chain.py b/templates/neo4j-generation/neo4j_generation/chain.py
index 8f7aa835879af..f980d25244674 100644
--- a/templates/neo4j-generation/neo4j_generation/chain.py
+++ b/templates/neo4j-generation/neo4j_generation/chain.py
@@ -1,11 +1,12 @@
-from typing import Optional, List
+from typing import List, Optional
+
from langchain.chains.openai_functions import (
create_structured_output_chain,
)
from langchain.chat_models import ChatOpenAI
-from langchain.prompts import ChatPromptTemplate
from langchain.graphs import Neo4jGraph
from langchain.graphs.graph_document import GraphDocument
+from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from neo4j_generation.utils import (
@@ -35,7 +36,7 @@ def get_extraction_chain(
If not provided, there won't be any specific restriction on node labels.
- allowed_rels (Optional[List[str]]): A list of relationship types that are allowed in the knowledge graph.
If not provided, there won't be any specific restriction on relationship types.
- """
+ """ # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
(
@@ -64,11 +65,12 @@ def get_extraction_chain(
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
## 5. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
- """,
+ """, # noqa: E501
),
(
"human",
- "Use the given format to extract information from the following input: {input}",
+ "Use the given format to extract information from the "
+ "following input: {input}",
),
("human", "Tip: Make sure to answer in the correct format"),
]
@@ -94,7 +96,7 @@ def chain(
Returns:
str: A confirmation message indicating the completion of the graph construction.
- """
+ """ # noqa: E501
# Extract graph data using OpenAI functions
extract_chain = get_extraction_chain(allowed_nodes, allowed_relationships)
data = extract_chain.run(text)
diff --git a/templates/neo4j-generation/neo4j_generation/utils.py b/templates/neo4j-generation/neo4j_generation/utils.py
index 00b086116b6dd..adaa9522889f3 100644
--- a/templates/neo4j-generation/neo4j_generation/utils.py
+++ b/templates/neo4j-generation/neo4j_generation/utils.py
@@ -1,9 +1,12 @@
+from typing import List, Optional
+
from langchain.graphs.graph_document import (
Node as BaseNode,
+)
+from langchain.graphs.graph_document import (
Relationship as BaseRelationship,
)
-from typing import List, Optional
-from langchain.pydantic_v1 import Field, BaseModel
+from langchain.pydantic_v1 import BaseModel, Field
class Property(BaseModel):
diff --git a/templates/neo4j-parent/ingest.py b/templates/neo4j-parent/ingest.py
index 7c8111925d267..f6a3f27e4ae8b 100644
--- a/templates/neo4j-parent/ingest.py
+++ b/templates/neo4j-parent/ingest.py
@@ -1,9 +1,10 @@
-from langchain.graphs import Neo4jGraph
-from langchain.vectorstores import Neo4jVector
+from pathlib import Path
+
from langchain.document_loaders import TextLoader
-from langchain.text_splitter import TokenTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
-from pathlib import Path
+from langchain.graphs import Neo4jGraph
+from langchain.text_splitter import TokenTextSplitter
+from langchain.vectorstores import Neo4jVector
txt_path = Path(__file__).parent / "dune.txt"
diff --git a/templates/neo4j-parent/main.py b/templates/neo4j-parent/main.py
index 990775d71cbf6..1c1b772d5265d 100644
--- a/templates/neo4j-parent/main.py
+++ b/templates/neo4j-parent/main.py
@@ -1,5 +1,4 @@
-from neo4j_parent.chain import chain
-
+from neo4j_parent.chain import chain
if __name__ == "__main__":
original_query = "What is the plot of the Dune?"
diff --git a/templates/neo4j-parent/neo4j_parent/chain.py b/templates/neo4j-parent/neo4j_parent/chain.py
index 703402b0728b8..59acbd662f883 100644
--- a/templates/neo4j-parent/neo4j_parent/chain.py
+++ b/templates/neo4j-parent/neo4j_parent/chain.py
@@ -1,8 +1,8 @@
-from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Neo4jVector
retrieval_query = """
diff --git a/templates/openai-functions-agent/openai_functions_agent/agent.py b/templates/openai-functions-agent/openai_functions_agent/agent.py
index 60cff54a1ef1f..ee8a2eb2ebdb3 100644
--- a/templates/openai-functions-agent/openai_functions_agent/agent.py
+++ b/templates/openai-functions-agent/openai_functions_agent/agent.py
@@ -1,15 +1,15 @@
from typing import List, Tuple
-from langchain.schema.messages import HumanMessage, AIMessage
-from langchain.chat_models import ChatOpenAI
+
from langchain.agents import AgentExecutor
-from langchain.utilities.tavily_search import TavilySearchAPIWrapper
-from langchain.tools.tavily_search import TavilySearchResults
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.tools.render import format_tool_to_openai_function
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel
-
+from langchain.schema.messages import AIMessage, HumanMessage
+from langchain.tools.render import format_tool_to_openai_function
+from langchain.tools.tavily_search import TavilySearchResults
+from langchain.utilities.tavily_search import TavilySearchAPIWrapper
# Fake Tool
search = TavilySearchAPIWrapper()
@@ -18,17 +18,21 @@
tools = [tavily_tool]
llm = ChatOpenAI(temperature=0)
-prompt = ChatPromptTemplate.from_messages([
- ("system", "You are very powerful assistant, but bad at calculating lengths of words."),
- MessagesPlaceholder(variable_name="chat_history"),
- ("user", "{input}"),
- MessagesPlaceholder(variable_name="agent_scratchpad"),
-])
-
-llm_with_tools = llm.bind(
- functions=[format_tool_to_openai_function(t) for t in tools]
+prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "You are very powerful assistant, but bad at calculating lengths of words.",
+ ),
+ MessagesPlaceholder(variable_name="chat_history"),
+ ("user", "{input}"),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
)
+llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
+
+
def _format_chat_history(chat_history: List[Tuple[str, str]]):
buffer = []
for human, ai in chat_history:
@@ -37,16 +41,25 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
return buffer
-agent = {
- "input": lambda x: x["input"],
- "chat_history": lambda x: _format_chat_history(x['chat_history']),
- "agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps']),
-} | prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
+agent = (
+ {
+ "input": lambda x: x["input"],
+ "chat_history": lambda x: _format_chat_history(x["chat_history"]),
+ "agent_scratchpad": lambda x: format_to_openai_functions(
+ x["intermediate_steps"]
+ ),
+ }
+ | prompt
+ | llm_with_tools
+ | OpenAIFunctionsAgentOutputParser()
+)
+
class AgentInput(BaseModel):
input: str
chat_history: List[Tuple[str, str]]
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_types(
input_type=AgentInput
)
diff --git a/templates/poetry.lock b/templates/poetry.lock
index a50742bbfa08f..3713751ae1e70 100644
--- a/templates/poetry.lock
+++ b/templates/poetry.lock
@@ -1,5 +1,144 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "docopt"
+version = "0.6.2"
+description = "Pythonic argument parser, that will make you smile"
+optional = false
+python-versions = "*"
+files = [
+ {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.1.3"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
+ {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
+[[package]]
+name = "packaging"
+version = "23.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
+ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+]
+
+[[package]]
+name = "pastel"
+version = "0.2.1"
+description = "Bring colors to your terminal."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"},
+ {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"},
+]
+
+[[package]]
+name = "pluggy"
+version = "1.3.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
+ {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "poethepoet"
+version = "0.24.1"
+description = "A task runner that works well with poetry."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "poethepoet-0.24.1-py3-none-any.whl", hash = "sha256:3afa44b4fc7327df0dd912eda012604a072af2bb4d243fb0e41e8eca8dabf9ed"},
+ {file = "poethepoet-0.24.1.tar.gz", hash = "sha256:f5a386387c382f08890c273d13495938208a8ce91ab71536abf388c776c4f366"},
+]
+
+[package.dependencies]
+pastel = ">=0.2.1,<0.3.0"
+tomli = ">=1.2.2"
+
+[package.extras]
+poetry-plugin = ["poetry (>=1.0,<2.0)"]
+
+[[package]]
+name = "pytest"
+version = "7.4.3"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"},
+ {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=0.12,<2.0"
+tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pytest-watch"
+version = "4.2.0"
+description = "Local continuous test runner with pytest and watchdog."
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-watch-4.2.0.tar.gz", hash = "sha256:06136f03d5b361718b8d0d234042f7b2f203910d8568f63df2f866b547b3d4b9"},
+]
+
+[package.dependencies]
+colorama = ">=0.3.3"
+docopt = ">=0.4.0"
+pytest = ">=2.6.4"
+watchdog = ">=0.6.0"
+
[[package]]
name = "ruff"
version = "0.1.2"
@@ -26,7 +165,57 @@ files = [
{file = "ruff-0.1.2.tar.gz", hash = "sha256:afd4785ae060ce6edcd52436d0c197628a918d6d09e3107a892a1bad6a4c6608"},
]
+[[package]]
+name = "tomli"
+version = "2.0.1"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+
+[[package]]
+name = "watchdog"
+version = "3.0.0"
+description = "Filesystem events monitoring"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"},
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"},
+ {file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"},
+ {file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"},
+ {file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"},
+ {file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"},
+ {file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"},
+ {file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"},
+ {file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"},
+ {file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"},
+ {file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"},
+ {file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"},
+ {file = "watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"},
+ {file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"},
+ {file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"},
+]
+
+[package.extras]
+watchmedo = ["PyYAML (>=3.10)"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "7a86271e260f3ac3de6446960b301dc6c854f9d7f9544774d81eaa13be511838"
+content-hash = "e00055a76b5e7a5dd6afd6c65073084a9a0f988a8d30e24be2606c048bd25686"
diff --git a/templates/pyproject.toml b/templates/pyproject.toml
index 7d9f6396dfbe0..2e682bbb76668 100644
--- a/templates/pyproject.toml
+++ b/templates/pyproject.toml
@@ -9,14 +9,19 @@ readme = "README.md"
python = "^3.10"
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+# dev, test, lint, typing
+[tool.poetry.group.dev.dependencies]
+poethepoet = "^0.24.1"
+pytest-watch = "^4.2.0"
+
+[tool.poetry.group.test.dependencies]
+pytest = "^7.4.3"
[tool.poetry.group.lint.dependencies]
ruff = "^0.1"
+[tool.poetry.group.typing.dependencies]
[tool.ruff]
select = [
@@ -24,3 +29,14 @@ select = [
"F", # pyflakes
"I", # isort
]
+
+[tool.poe.tasks]
+test = "poetry run pytest"
+watch = "poetry run ptw"
+lint = "poetry run ruff ."
+format = "poetry run ruff . --fix"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
\ No newline at end of file
diff --git a/templates/rag-chroma-private/rag_chroma_private/chain.py b/templates/rag-chroma-private/rag_chroma_private/chain.py
index 0e6b659aa7cd4..397a1589b6747 100644
--- a/templates/rag-chroma-private/rag_chroma_private/chain.py
+++ b/templates/rag-chroma-private/rag_chroma_private/chain.py
@@ -1,28 +1,31 @@
-from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOllama
-from langchain.prompts import ChatPromptTemplate
-from langchain.embeddings import GPT4AllEmbeddings
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
# Load
from langchain.document_loaders import WebBaseLoader
+from langchain.embeddings import GPT4AllEmbeddings
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()
# Split
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
# Add to vectorDB
-vectorstore = Chroma.from_documents(documents=all_splits,
- collection_name="rag-private",
- embedding=GPT4AllEmbeddings(),
- )
+vectorstore = Chroma.from_documents(
+ documents=all_splits,
+ collection_name="rag-private",
+ embedding=GPT4AllEmbeddings(),
+)
retriever = vectorstore.as_retriever()
-# Prompt
+# Prompt
# Optionally, pull from the Hub
# from langchain import hub
# prompt = hub.pull("rlm/rag-prompt")
diff --git a/templates/rag-chroma/rag_chroma/chain.py b/templates/rag-chroma/rag_chroma/chain.py
index ecf4a9ca5e1fb..96a46fd4c68a4 100644
--- a/templates/rag-chroma/rag_chroma/chain.py
+++ b/templates/rag-chroma/rag_chroma/chain.py
@@ -1,9 +1,8 @@
-from operator import itemgetter
-from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import Chroma
# Example for document loading (from url), splitting, and creating vectostore
diff --git a/templates/rag-conversation/rag_conversation/chain.py b/templates/rag-conversation/rag_conversation/chain.py
index a0f0be1ba4c9c..3ed98288b4c3f 100644
--- a/templates/rag-conversation/rag_conversation/chain.py
+++ b/templates/rag-conversation/rag_conversation/chain.py
@@ -1,15 +1,21 @@
import os
-from typing import Tuple, List
-from pydantic import BaseModel
from operator import itemgetter
-from langchain.vectorstores import Pinecone
+from typing import List, Tuple
+
from langchain.chat_models import ChatOpenAI
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.embeddings import OpenAIEmbeddings
-from langchain.schema import format_document, AIMessage, HumanMessage
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import AIMessage, HumanMessage, format_document
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda, RunnableMap
+from langchain.schema.runnable import (
+ RunnableBranch,
+ RunnableLambda,
+ RunnableMap,
+ RunnablePassthrough,
+)
+from langchain.vectorstores import Pinecone
+from pydantic import BaseModel
if os.environ.get("PINECONE_API_KEY", None) is None:
raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@@ -44,7 +50,7 @@
Chat History:
{chat_history}
Follow Up Input: {question}
-Standalone question:"""
+Standalone question:""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
# RAG answer synthesis prompt
@@ -52,18 +58,25 @@
{context}
"""
-ANSWER_PROMPT = ChatPromptTemplate.from_messages([
- ("system",template),
- MessagesPlaceholder(variable_name="chat_history"),
- ("user", "{question}")
-])
+ANSWER_PROMPT = ChatPromptTemplate.from_messages(
+ [
+ ("system", template),
+ MessagesPlaceholder(variable_name="chat_history"),
+ ("user", "{question}"),
+ ]
+)
# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
-def _combine_documents(docs, document_prompt = DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
+
+
+def _combine_documents(
+ docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
+
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
buffer = []
for human, ai in chat_history:
@@ -71,6 +84,7 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
buffer.append(AIMessage(content=ai))
return buffer
+
# User input
class ChatHistory(BaseModel):
chat_history: List[Tuple[str, str]]
@@ -78,24 +92,28 @@ class ChatHistory(BaseModel):
_search_query = RunnableBranch(
- # If input includes chat_history, we condense it with the follow-up question
- (
- RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
- run_name="HasChatHistoryCheck"
- ), # Condense follow-up question and chat into a standalone_question
- RunnablePassthrough.assign(
- chat_history=lambda x: _format_chat_history(x['chat_history'])
- ) | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),
- ),
- # Else, we have no chat history, so just pass through the question
- RunnableLambda(itemgetter("question"))
-
- )
-
-_inputs = RunnableMap({
- "question": lambda x: x["question"],
- "chat_history": lambda x: _format_chat_history(x['chat_history']),
- "context": _search_query | retriever | _combine_documents
-}).with_types(input_type=ChatHistory)
+ # If input includes chat_history, we condense it with the follow-up question
+ (
+ RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+ run_name="HasChatHistoryCheck"
+ ), # Condense follow-up question and chat into a standalone_question
+ RunnablePassthrough.assign(
+ chat_history=lambda x: _format_chat_history(x["chat_history"])
+ )
+ | CONDENSE_QUESTION_PROMPT
+ | ChatOpenAI(temperature=0)
+ | StrOutputParser(),
+ ),
+ # Else, we have no chat history, so just pass through the question
+ RunnableLambda(itemgetter("question")),
+)
+
+_inputs = RunnableMap(
+ {
+ "question": lambda x: x["question"],
+ "chat_history": lambda x: _format_chat_history(x["chat_history"]),
+ "context": _search_query | retriever | _combine_documents,
+ }
+).with_types(input_type=ChatHistory)
chain = _inputs | ANSWER_PROMPT | ChatOpenAI() | StrOutputParser()
diff --git a/templates/rag-elasticsearch/ingest.py b/templates/rag-elasticsearch/ingest.py
index 0ba72df0afacb..5b6db39d09edf 100644
--- a/templates/rag-elasticsearch/ingest.py
+++ b/templates/rag-elasticsearch/ingest.py
@@ -1,8 +1,9 @@
+import os
+
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores.elasticsearch import ElasticsearchStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
-import os
+from langchain.vectorstores.elasticsearch import ElasticsearchStore
ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID")
ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME", "elastic")
diff --git a/templates/rag-elasticsearch/main.py b/templates/rag-elasticsearch/main.py
index 6ca4a83b79a52..4034ab08f26f4 100644
--- a/templates/rag-elasticsearch/main.py
+++ b/templates/rag-elasticsearch/main.py
@@ -23,7 +23,9 @@
"question": follow_up_question,
"chat_history": [
"What is the nasa sales team?",
- "The sales team of NASA consists of Laura Martinez, the Area Vice-President of North America, and Gary Johnson, the Area Vice-President of South America. (Sales Organization Overview)",
+ "The sales team of NASA consists of Laura Martinez, the Area "
+ "Vice-President of North America, and Gary Johnson, the Area "
+ "Vice-President of South America. (Sales Organization Overview)",
],
}
)
diff --git a/templates/rag-elasticsearch/rag_elasticsearch/chain.py b/templates/rag-elasticsearch/rag_elasticsearch/chain.py
index 220a546b5af0d..9dd41ae9e965e 100644
--- a/templates/rag-elasticsearch/rag_elasticsearch/chain.py
+++ b/templates/rag-elasticsearch/rag_elasticsearch/chain.py
@@ -1,13 +1,15 @@
+from operator import itemgetter
+from typing import List, Tuple
+
from langchain.chat_models import ChatOpenAI
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableMap
from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores.elasticsearch import ElasticsearchStore
from langchain.schema import format_document
-from typing import Tuple, List
-from operator import itemgetter
-from .prompts import CONDENSE_QUESTION_PROMPT, LLM_CONTEXT_PROMPT, DOCUMENT_PROMPT
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import RunnableMap, RunnablePassthrough
+from langchain.vectorstores.elasticsearch import ElasticsearchStore
+
from .connection import es_connection_details
+from .prompts import CONDENSE_QUESTION_PROMPT, DOCUMENT_PROMPT, LLM_CONTEXT_PROMPT
# Setup connecting to Elasticsearch
vectorstore = ElasticsearchStore(
diff --git a/templates/rag-elasticsearch/rag_elasticsearch/prompts.py b/templates/rag-elasticsearch/rag_elasticsearch/prompts.py
index 431692fb4f8f0..ca1e588242d11 100644
--- a/templates/rag-elasticsearch/rag_elasticsearch/prompts.py
+++ b/templates/rag-elasticsearch/rag_elasticsearch/prompts.py
@@ -6,7 +6,7 @@
Chat History:
{chat_history}
Follow Up Input: {question}
-"""
+""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
condense_question_prompt_template
)
@@ -23,7 +23,7 @@
{context}
----
Question: {question}
-"""
+""" # noqa: E501
LLM_CONTEXT_PROMPT = ChatPromptTemplate.from_template(llm_context_prompt_template)
diff --git a/templates/rag-fusion/ingest.py b/templates/rag-fusion/ingest.py
index a21fff11572b5..9604658e4daf0 100644
--- a/templates/rag-fusion/ingest.py
+++ b/templates/rag-fusion/ingest.py
@@ -1,8 +1,8 @@
import pinecone
-from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings
+from langchain.vectorstores import Pinecone
-pinecone.init(api_key="...",environment="...")
+pinecone.init(api_key="...", environment="...")
all_documents = {
"doc1": "Climate change and economic impact.",
@@ -14,7 +14,9 @@
"doc7": "Climate change: The science and models.",
"doc8": "Global warming: A subset of climate change.",
"doc9": "How climate change affects daily weather.",
- "doc10": "The history of climate change activism."
+ "doc10": "The history of climate change activism.",
}
-Pinecone.from_texts(list(all_documents.values()), OpenAIEmbeddings(), index_name='rag-fusion')
+Pinecone.from_texts(
+ list(all_documents.values()), OpenAIEmbeddings(), index_name="rag-fusion"
+)
diff --git a/templates/rag-fusion/main.py b/templates/rag-fusion/main.py
index a283e8b2c1c9a..642b48eaa362b 100644
--- a/templates/rag-fusion/main.py
+++ b/templates/rag-fusion/main.py
@@ -1,5 +1,4 @@
-from rag_fusion.chain import chain
-
+from rag_fusion.chain import chain
if __name__ == "__main__":
original_query = "impact of climate change"
diff --git a/templates/rag-fusion/rag_fusion/chain.py b/templates/rag-fusion/rag_fusion/chain.py
index e6cedfaa9057a..6d8f2892d2c6e 100644
--- a/templates/rag-fusion/rag_fusion/chain.py
+++ b/templates/rag-fusion/rag_fusion/chain.py
@@ -1,11 +1,11 @@
-from langchain.chat_models import ChatOpenAI
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema.output_parser import StrOutputParser
-from langchain import hub
import pinecone
-from langchain.vectorstores import Pinecone
+from langchain import hub
+from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.load import dumps, loads
+from langchain.schema.output_parser import StrOutputParser
+from langchain.vectorstores import Pinecone
+
def reciprocal_rank_fusion(results: list[list], k=60):
fused_scores = {}
@@ -15,19 +15,29 @@ def reciprocal_rank_fusion(results: list[list], k=60):
doc_str = dumps(doc)
if doc_str not in fused_scores:
fused_scores[doc_str] = 0
- previous_score = fused_scores[doc_str]
fused_scores[doc_str] += 1 / (rank + k)
-
- reranked_results = [(loads(doc), score) for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)]
- return reranked_results
+
+ reranked_results = [
+ (loads(doc), score)
+ for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
+ ]
+ return reranked_results
+
pinecone.init(api_key="...", environment="...")
-prompt = hub.pull('langchain-ai/rag-fusion-query-generation')
+prompt = hub.pull("langchain-ai/rag-fusion-query-generation")
-generate_queries = prompt | ChatOpenAI(temperature=0) | StrOutputParser() | (lambda x: x.split("\n"))
+generate_queries = (
+ prompt | ChatOpenAI(temperature=0) | StrOutputParser() | (lambda x: x.split("\n"))
+)
vectorstore = Pinecone.from_existing_index("rag-fusion", OpenAIEmbeddings())
retriever = vectorstore.as_retriever()
-chain = {"original_query": lambda x: x} | generate_queries | retriever.map() | reciprocal_rank_fusion
+chain = (
+ {"original_query": lambda x: x}
+ | generate_queries
+ | retriever.map()
+ | reciprocal_rank_fusion
+)
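Since reciprocal_rank_fusion is pure Python, its scoring can be checked without Pinecone; a minimal self-contained sketch of the same 1 / (rank + k) formula on made-up ranked lists:

def rrf_demo(results, k=60):
    # Same scoring as reciprocal_rank_fusion above, but on plain strings
    # instead of serialized Documents.
    fused = {}
    for ranking in results:
        for rank, doc in enumerate(ranking):
            fused[doc] = fused.get(doc, 0) + 1 / (rank + k)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)

# "doc2" is ranked highly by both queries, so it comes out on top.
print(rrf_demo([["doc1", "doc2", "doc3"], ["doc2", "doc3", "doc1"]]))
# [('doc2', 0.0330...), ('doc1', 0.0327...), ('doc3', 0.0325...)]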
diff --git a/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py b/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
index 70891a1ad5a0f..b676de6e9c06c 100644
--- a/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
+++ b/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
@@ -1,13 +1,12 @@
-import os
-import pinecone
-from operator import itemgetter
-from langchain.vectorstores import Pinecone
-from langchain.prompts import ChatPromptTemplate
+import os
+
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
-from langchain.schema.output_parser import StrOutputParser
+from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
+from langchain.vectorstores import Pinecone
if os.environ.get("PINECONE_API_KEY", None) is None:
raise Exception("Missing `PINECONE_API_KEY` environment variable.")
diff --git a/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py b/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
index 6d6701d731f8c..46171f4f16523 100644
--- a/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
+++ b/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
@@ -1,14 +1,13 @@
-import os
-import pinecone
-from operator import itemgetter
-from langchain.vectorstores import Pinecone
-from langchain.prompts import ChatPromptTemplate
+import os
+
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
-from langchain.schema.output_parser import StrOutputParser
+from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
+from langchain.vectorstores import Pinecone
if os.environ.get("PINECONE_API_KEY", None) is None:
raise Exception("Missing `PINECONE_API_KEY` environment variable.")
@@ -38,7 +37,7 @@
vectorstore = Pinecone.from_existing_index(PINECONE_INDEX_NAME, OpenAIEmbeddings())
# Get k=10 docs
-retriever = vectorstore.as_retriever(search_kwargs={"k":10})
+retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
# Re-rank
compressor = CohereRerank()
@@ -56,7 +55,9 @@
# RAG
model = ChatOpenAI()
chain = (
- RunnableParallel({"context": compression_retriever, "question": RunnablePassthrough()})
+ RunnableParallel(
+ {"context": compression_retriever, "question": RunnablePassthrough()}
+ )
| prompt
| model
| StrOutputParser()
diff --git a/templates/rag-pinecone/rag_pinecone/chain.py b/templates/rag-pinecone/rag_pinecone/chain.py
index 6f2cc701ed0ca..6777010d2a149 100644
--- a/templates/rag-pinecone/rag_pinecone/chain.py
+++ b/templates/rag-pinecone/rag_pinecone/chain.py
@@ -1,12 +1,11 @@
-import os
-import pinecone
-from operator import itemgetter
-from langchain.vectorstores import Pinecone
-from langchain.prompts import ChatPromptTemplate
+import os
+
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
+from langchain.vectorstores import Pinecone
if os.environ.get("PINECONE_API_KEY", None) is None:
raise Exception("Missing `PINECONE_API_KEY` environment variable.")
diff --git a/templates/rag-semi-structured/rag_semi_structured/chain.py b/templates/rag-semi-structured/rag_semi_structured/chain.py
index 6622e83234e72..e5e923ab55a04 100644
--- a/templates/rag-semi-structured/rag_semi_structured/chain.py
+++ b/templates/rag-semi-structured/rag_semi_structured/chain.py
@@ -1,33 +1,36 @@
# Load
import uuid
+
from langchain.chat_models import ChatOpenAI
+from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
+from langchain.retrievers.multi_vector import MultiVectorRetriever
+from langchain.schema.document import Document
from langchain.schema.output_parser import StrOutputParser
-from langchain.vectorstores import Chroma
+from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
+from langchain.vectorstores import Chroma
from unstructured.partition.pdf import partition_pdf
-from langchain.schema.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.retrievers.multi_vector import MultiVectorRetriever
-from langchain.schema.runnable import RunnablePassthrough
# Path to docs
path = "docs"
-raw_pdf_elements = partition_pdf(filename=path+"LLaMA2.pdf",
- # Unstructured first finds embedded image blocks
- extract_images_in_pdf=False,
- # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
- # Titles are any sub-section of the document
- infer_table_structure=True,
- # Post processing to aggregate text once we have the title
- chunking_strategy="by_title",
- # Chunking params to aggregate text blocks
- # Attempt to create a new chunk 3800 chars
- # Attempt to keep chunks > 2000 chars
- max_characters=4000,
- new_after_n_chars=3800,
- combine_text_under_n_chars=2000,
- image_output_dir_path=path)
+raw_pdf_elements = partition_pdf(
+ filename=path + "LLaMA2.pdf",
+ # Unstructured first finds embedded image blocks
+ extract_images_in_pdf=False,
+ # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
+ # Titles are any sub-section of the document
+ infer_table_structure=True,
+ # Post processing to aggregate text once we have the title
+ chunking_strategy="by_title",
+ # Chunking params to aggregate text blocks
+ # Attempt to create a new chunk 3800 chars
+ # Attempt to keep chunks > 2000 chars
+ max_characters=4000,
+ new_after_n_chars=3800,
+ combine_text_under_n_chars=2000,
+ image_output_dir_path=path,
+)
# Categorize by type
tables = []
@@ -40,26 +43,23 @@
# Summarize
-prompt_text="""You are an assistant tasked with summarizing tables and text. \
+prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """
-prompt = ChatPromptTemplate.from_template(prompt_text)
-model = ChatOpenAI(temperature=0,model="gpt-4")
-summarize_chain = {"element": lambda x:x} | prompt | model | StrOutputParser()
+prompt = ChatPromptTemplate.from_template(prompt_text)
+model = ChatOpenAI(temperature=0, model="gpt-4")
+summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
# Apply
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
# To save time / cost, only do text summaries if chunk sizes are large
# text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
-# We can just assign text_summaries to the raw texts
+# We can just assign text_summaries to the raw texts
text_summaries = texts
# Use multi vector retriever
# The vectorstore to use to index the child chunks
-vectorstore = Chroma(
- collection_name="summaries",
- embedding_function=OpenAIEmbeddings()
-)
+vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())
# The storage layer for the parent documents
store = InMemoryStore()
@@ -67,20 +67,26 @@
# The retriever (empty to start)
retriever = MultiVectorRetriever(
- vectorstore=vectorstore,
- docstore=store,
+ vectorstore=vectorstore,
+ docstore=store,
id_key=id_key,
)
# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
-summary_texts = [Document(page_content=s,metadata={id_key: doc_ids[i]}) for i, s in enumerate(text_summaries)]
+summary_texts = [
+ Document(page_content=s, metadata={id_key: doc_ids[i]})
+ for i, s in enumerate(text_summaries)
+]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))
# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
-summary_tables = [Document(page_content=s,metadata={id_key: table_ids[i]}) for i, s in enumerate(table_summaries)]
+summary_tables = [
+ Document(page_content=s, metadata={id_key: table_ids[i]})
+ for i, s in enumerate(table_summaries)
+]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))
@@ -90,16 +96,16 @@
template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
-"""
+""" # noqa: E501
prompt = ChatPromptTemplate.from_template(template)
# LLM
-model = ChatOpenAI(temperature=0,model="gpt-4")
+model = ChatOpenAI(temperature=0, model="gpt-4")
# RAG pipeline
chain = (
- {"context": retriever, "question": RunnablePassthrough()}
- | prompt
- | model
+ {"context": retriever, "question": RunnablePassthrough()}
+ | prompt
+ | model
| StrOutputParser()
-)
\ No newline at end of file
+)
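A hypothetical invocation of the semi-structured chain above; it assumes the LLaMA2 PDF, the unstructured PDF extras, and OPENAI_API_KEY are available, and the question is illustrative:

# The MultiVectorRetriever matches against the summaries stored in Chroma and
# then returns the full parent text/table kept in the InMemoryStore.
print(chain.invoke("What hardware was used to train the model?"))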
diff --git a/templates/rag-supabase/rag_supabase/chain.py b/templates/rag-supabase/rag_supabase/chain.py
index ef35319dcdb12..e116e840c120d 100644
--- a/templates/rag-supabase/rag_supabase/chain.py
+++ b/templates/rag-supabase/rag_supabase/chain.py
@@ -1,13 +1,12 @@
import os
-from langchain.prompts import ChatPromptTemplate
+
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
-
-from supabase.client import create_client
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
+from supabase.client import create_client
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@@ -19,7 +18,7 @@
client=supabase,
embedding=embeddings,
table_name="documents",
- query_name="match_documents"
+ query_name="match_documents",
)
retriever = vectorstore.as_retriever()
diff --git a/templates/rewrite-retrieve-read/main.py b/templates/rewrite-retrieve-read/main.py
index 3b0ce28925137..deedf470341fe 100644
--- a/templates/rewrite-retrieve-read/main.py
+++ b/templates/rewrite-retrieve-read/main.py
@@ -1,5 +1,4 @@
from rewrite_retrieve_read.chain import chain
-
if __name__ == "__main__":
chain.invoke("man that sam bankman fried trial was crazy! what is langchain?")
diff --git a/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py b/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py
index 7094f9132c3af..456db4204d3cb 100644
--- a/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py
+++ b/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py
@@ -1,9 +1,8 @@
-from operator import itemgetter
-from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
+from langchain.schema.runnable import RunnablePassthrough
from langchain.utilities import DuckDuckGoSearchAPIWrapper
template = """Answer the users question based only on the following context:
diff --git a/templates/self-query-supabase/self_query_supabase/chain.py b/templates/self-query-supabase/self_query_supabase/chain.py
index c32f5154a6a01..eca0c0e82c4b8 100644
--- a/templates/self-query-supabase/self_query_supabase/chain.py
+++ b/templates/self-query-supabase/self_query_supabase/chain.py
@@ -1,15 +1,12 @@
import os
+
+from langchain.chains.query_constructor.base import AttributeInfo
+from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
-from langchain.chains.query_constructor.base import AttributeInfo
-
-from supabase.client import create_client
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores.supabase import SupabaseVectorStore
+from supabase.client import create_client
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
@@ -21,7 +18,7 @@
client=supabase,
embedding=embeddings,
table_name="documents",
- query_name="match_documents"
+ query_name="match_documents",
)
# Adjust this based on the metadata you store in the `metadata` JSON column
@@ -51,14 +48,7 @@
llm = OpenAI(temperature=0)
retriever = SelfQueryRetriever.from_llm(
- llm,
- vectorstore,
- document_content_description,
- metadata_field_info,
- verbose=True
+ llm, vectorstore, document_content_description, metadata_field_info, verbose=True
)
-chain = (
- RunnableParallel({"query": RunnablePassthrough()})
- | retriever
-)
+chain = RunnableParallel({"query": RunnablePassthrough()}) | retriever
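The metadata_field_info block sits between the hunks above and is template-specific; a hypothetical sketch of the AttributeInfo entries it expects, with invented field names:

from langchain.chains.query_constructor.base import AttributeInfo

# Placeholder schema; replace the names/types with whatever the Supabase
# `metadata` JSON column actually stores.
metadata_field_info = [
    AttributeInfo(name="author", description="Author of the document", type="string"),
    AttributeInfo(name="year", description="Year the document was written", type="integer"),
]
document_content_description = "Internal knowledge-base articles"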
diff --git a/templates/sql-llama2/sql_llama2/chain.py b/templates/sql-llama2/sql_llama2/chain.py
index 235301252bacd..7e53f5d147c43 100644
--- a/templates/sql-llama2/sql_llama2/chain.py
+++ b/templates/sql-llama2/sql_llama2/chain.py
@@ -1,23 +1,26 @@
+from pathlib import Path
+
from langchain.llms import Replicate
+from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
-from langchain.prompts import ChatPromptTemplate
+from langchain.utilities import SQLDatabase
# make sure to set REPLICATE_API_TOKEN in your environment
# use llama-2-13b model in replicate
-replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d" # noqa: E501
llm = Replicate(
model=replicate_id,
model_kwargs={"temperature": 0.01, "max_length": 500, "top_p": 1},
)
-from pathlib import Path
-from langchain.utilities import SQLDatabase
+
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
+
def get_schema(_):
return db.get_table_info()
@@ -30,7 +33,7 @@ def run_query(query):
{schema}
Question: {question}
-SQL Query:"""
+SQL Query:""" # noqa: E501
prompt = ChatPromptTemplate.from_messages(
[
("system", "Given an input question, convert it to a SQL query. No pre-amble."),
@@ -50,13 +53,14 @@ def run_query(query):
Question: {question}
SQL Query: {query}
-SQL Response: {response}"""
+SQL Response: {response}""" # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
[
(
"system",
- "Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",
+ "Given an input question and SQL response, convert it to a natural "
+ "language answer. No pre-amble.",
),
("human", template_response),
]
diff --git a/templates/sql-llamacpp/sql_llamacpp/chain.py b/templates/sql-llamacpp/sql_llamacpp/chain.py
index bf8e796d0e53d..4a00069988dd4 100644
--- a/templates/sql-llamacpp/sql_llamacpp/chain.py
+++ b/templates/sql-llamacpp/sql_llamacpp/chain.py
@@ -1,11 +1,15 @@
-from langchain.llms import LlamaCpp
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough
-
# Get LLM
import os
+from pathlib import Path
+
import requests
+from langchain.llms import LlamaCpp
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+from langchain.utilities import SQLDatabase
+
# File name and URL
file_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
url = "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
@@ -15,7 +19,7 @@
# Download the file
response = requests.get(url)
response.raise_for_status() # Raise an exception for HTTP errors
- with open(file_name, 'wb') as f:
+ with open(file_name, "wb") as f:
f.write(response.content)
print(f"'{file_name}' has been downloaded.")
else:
@@ -24,23 +28,27 @@
# Add the LLM downloaded from HF
model_path = file_name
n_gpu_layers = 1 # Metal set to 1 is enough.
-n_batch = 512 # Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.
+
+# Should be between 1 and n_ctx; consider the amount of RAM on your Apple Silicon chip.
+n_batch = 512
+
llm = LlamaCpp(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_batch=n_batch,
n_ctx=2048,
- f16_kv=True, # MUST set to True, otherwise you will run into problem after a couple of calls
+ # f16_kv MUST be set to True,
+ # otherwise you will run into problems after a couple of calls
+ f16_kv=True,
verbose=True,
)
-from pathlib import Path
-from langchain.utilities import SQLDatabase
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
+
def get_schema(_):
return db.get_table_info()
@@ -48,39 +56,43 @@ def get_schema(_):
def run_query(query):
return db.run(query)
+
# Prompt
-from langchain.memory import ConversationBufferMemory
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
Question: {question}
-SQL Query:"""
-prompt = ChatPromptTemplate.from_messages([
- ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
- MessagesPlaceholder(variable_name="history"),
- ("human", template)
-])
+SQL Query:""" # noqa: E501
+prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
+ MessagesPlaceholder(variable_name="history"),
+ ("human", template),
+ ]
+)
memory = ConversationBufferMemory(return_messages=True)
-# Chain to query with memory
-from langchain.schema.runnable import RunnableLambda
+# Chain to query with memory
sql_chain = (
RunnablePassthrough.assign(
- schema=get_schema,
- history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
- )| prompt
+ schema=get_schema,
+ history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
+ )
+ | prompt
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
)
+
def save(input_output):
output = {"output": input_output.pop("output")}
memory.save_context(input_output, output)
- return output['output']
-
+ return output["output"]
+
+
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
# Chain to answer
@@ -89,18 +101,24 @@ def save(input_output):
Question: {question}
SQL Query: {query}
-SQL Response: {response}"""
-prompt_response = ChatPromptTemplate.from_messages([
- ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
- ("human", template)
-])
+SQL Response: {response}""" # noqa: E501
+prompt_response = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "Given an input question and SQL response, convert it to a natural "
+ "language answer. No pre-amble.",
+ ),
+ ("human", template),
+ ]
+)
chain = (
- RunnablePassthrough.assign(query=sql_response_memory)
+ RunnablePassthrough.assign(query=sql_response_memory)
| RunnablePassthrough.assign(
schema=get_schema,
response=lambda x: db.run(x["query"]),
)
- | prompt_response
+ | prompt_response
| llm
)
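A hypothetical two-turn run of the chain above; it assumes the GGUF model has been downloaded and nba_roster.db sits next to the module, and the questions are invented:

# The first answer is written into ConversationBufferMemory by save(), so the
# follow-up question can refer back to it via the `history` placeholder.
print(chain.invoke({"question": "Which team does Klay Thompson play for?"}))
print(chain.invoke({"question": "And what is his jersey number?"}))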
diff --git a/templates/sql-ollama/sql_ollama/chain.py b/templates/sql-ollama/sql_ollama/chain.py
index 248e5c8a5b46c..7cf55746b8878 100644
--- a/templates/sql-ollama/sql_ollama/chain.py
+++ b/templates/sql-ollama/sql_ollama/chain.py
@@ -1,58 +1,67 @@
+from pathlib import Path
+
from langchain.chat_models import ChatOllama
-from langchain.prompts import ChatPromptTemplate
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough
+from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+from langchain.utilities import SQLDatabase
# Add the LLM downloaded from Ollama
ollama_llm = "llama2:13b-chat"
llm = ChatOllama(model=ollama_llm)
-from pathlib import Path
-from langchain.utilities import SQLDatabase
+
db_path = Path(__file__).parent / "nba_roster.db"
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0)
+
def get_schema(_):
return db.get_table_info()
+
def run_query(query):
return db.run(query)
+
# Prompt
-from langchain.memory import ConversationBufferMemory
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
Question: {question}
-SQL Query:"""
-prompt = ChatPromptTemplate.from_messages([
- ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
- MessagesPlaceholder(variable_name="history"),
- ("human", template)
-])
+SQL Query:""" # noqa: E501
+prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
+ MessagesPlaceholder(variable_name="history"),
+ ("human", template),
+ ]
+)
memory = ConversationBufferMemory(return_messages=True)
-# Chain to query with memory
-from langchain.schema.runnable import RunnableLambda
+# Chain to query with memory
sql_chain = (
RunnablePassthrough.assign(
- schema=get_schema,
- history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"])
- )| prompt
+ schema=get_schema,
+ history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
+ )
+ | prompt
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
)
+
def save(input_output):
output = {"output": input_output.pop("output")}
memory.save_context(input_output, output)
- return output['output']
-
+ return output["output"]
+
+
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
# Chain to answer
@@ -61,18 +70,24 @@ def save(input_output):
Question: {question}
SQL Query: {query}
-SQL Response: {response}"""
-prompt_response = ChatPromptTemplate.from_messages([
- ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
- ("human", template)
-])
+SQL Response: {response}""" # noqa: E501
+prompt_response = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "Given an input question and SQL response, convert it to a natural "
+ "language answer. No pre-amble.",
+ ),
+ ("human", template),
+ ]
+)
chain = (
- RunnablePassthrough.assign(query=sql_response_memory)
+ RunnablePassthrough.assign(query=sql_response_memory)
| RunnablePassthrough.assign(
schema=get_schema,
response=lambda x: db.run(x["query"]),
)
- | prompt_response
+ | prompt_response
| llm
)
diff --git a/templates/stepback-qa-prompting/main.py b/templates/stepback-qa-prompting/main.py
index 70cb5d5016947..05e0455b49b72 100644
--- a/templates/stepback-qa-prompting/main.py
+++ b/templates/stepback-qa-prompting/main.py
@@ -1,5 +1,4 @@
from stepback_qa_prompting.chain import chain
-
if __name__ == "__main__":
chain.invoke({"question": "was chatgpt around while trump was president?"})
diff --git a/templates/stepback-qa-prompting/stepback_qa_prompting/chain.py b/templates/stepback-qa-prompting/stepback_qa_prompting/chain.py
index 6fd256e2fb855..f25802ef264fe 100644
--- a/templates/stepback-qa-prompting/stepback_qa_prompting/chain.py
+++ b/templates/stepback-qa-prompting/stepback_qa_prompting/chain.py
@@ -4,9 +4,9 @@
from langchain.schema.runnable import RunnableLambda
from langchain.utilities import DuckDuckGoSearchAPIWrapper
-
search = DuckDuckGoSearchAPIWrapper(max_results=4)
+
def retriever(query):
return search.run(query)
@@ -15,11 +15,11 @@ def retriever(query):
examples = [
{
"input": "Could the members of The Police perform lawful arrests?",
- "output": "what can the members of The Police do?"
+ "output": "what can the members of The Police do?",
},
{
- "input": "Jan Sindel’s was born in what country?",
- "output": "what is Jan Sindel’s personal history?"
+ "input": "Jan Sindel’s was born in what country?",
+ "output": "what is Jan Sindel’s personal history?",
},
]
# We now transform these to example messages
@@ -34,13 +34,20 @@ def retriever(query):
examples=examples,
)
-prompt = ChatPromptTemplate.from_messages([
- ("system", """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Here are a few examples:"""),
- # Few shot examples
- few_shot_prompt,
- # New question
- ("user", "{question}"),
-])
+prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "You are an expert at world knowledge. Your task is to step back "
+ "and paraphrase a question to a more generic step-back question, which "
+ "is easier to answer. Here are a few examples:",
+ ),
+ # Few shot examples
+ few_shot_prompt,
+ # New question
+ ("user", "{question}"),
+ ]
+)
question_gen = prompt | ChatOpenAI(temperature=0) | StrOutputParser()
@@ -50,16 +57,19 @@ def retriever(query):
{step_back_context}
Original Question: {question}
-Answer:"""
+Answer:""" # noqa: E501
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)
-chain = {
- # Retrieve context using the normal question
- "normal_context": RunnableLambda(lambda x: x['question']) | retriever,
- # Retrieve context using the step-back question
- "step_back_context": question_gen | retriever,
- # Pass on the question
- "question": lambda x: x["question"]
-} | response_prompt | ChatOpenAI(temperature=0) | StrOutputParser()
-
-
+chain = (
+ {
+ # Retrieve context using the normal question
+ "normal_context": RunnableLambda(lambda x: x["question"]) | retriever,
+ # Retrieve context using the step-back question
+ "step_back_context": question_gen | retriever,
+ # Pass on the question
+ "question": lambda x: x["question"],
+ }
+ | response_prompt
+ | ChatOpenAI(temperature=0)
+ | StrOutputParser()
+)
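For a quick check of the reformatted pipeline, the intermediate step-back question can be inspected on its own; this assumes OPENAI_API_KEY is set and reuses the question from main.py:

# Runs only the question-generation stage, before either retriever call.
print(question_gen.invoke({"question": "Was ChatGPT around while Trump was president?"}))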
diff --git a/templates/xml-agent/xml_agent/agent.py b/templates/xml-agent/xml_agent/agent.py
index cf1d203155790..ae638c3b929f8 100644
--- a/templates/xml-agent/xml_agent/agent.py
+++ b/templates/xml-agent/xml_agent/agent.py
@@ -1,13 +1,18 @@
-from langchain.chat_models import ChatAnthropic
-from langchain.tools.render import render_text_description
-from langchain.agents.format_scratchpad import format_xml
+from typing import List, Tuple
+
from langchain.agents import AgentExecutor
-from langchain.retrievers.you import YouRetriever
-from langchain.agents.agent_toolkits.conversational_retrieval.tool import create_retriever_tool
+from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
+ create_retriever_tool,
+)
+from langchain.agents.format_scratchpad import format_xml
+from langchain.chat_models import ChatAnthropic
from langchain.pydantic_v1 import BaseModel
-from xml_agent.prompts import conversational_prompt, parse_output
+from langchain.retrievers.you import YouRetriever
from langchain.schema import AIMessage, HumanMessage
-from typing import List, Tuple
+from langchain.tools.render import render_text_description
+
+from xml_agent.prompts import conversational_prompt, parse_output
+
def _format_chat_history(chat_history: List[Tuple[str, str]]):
buffer = []
@@ -21,7 +26,9 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
# Fake Tool
retriever = YouRetriever(k=5)
-retriever_tool = create_retriever_tool(retriever, "search", "Use this to search for current events.")
+retriever_tool = create_retriever_tool(
+ retriever, "search", "Use this to search for current events."
+)
tools = [retriever_tool]
@@ -31,18 +38,25 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
)
llm_with_stop = model.bind(stop=["</tool_input>"])
-agent = {
- "question": lambda x: x["question"],
- "agent_scratchpad": lambda x: format_xml(x['intermediate_steps']),
- "chat_history": lambda x: _format_chat_history(x["chat_history"]),
-} | prompt | llm_with_stop | parse_output
+agent = (
+ {
+ "question": lambda x: x["question"],
+ "agent_scratchpad": lambda x: format_xml(x["intermediate_steps"]),
+ "chat_history": lambda x: _format_chat_history(x["chat_history"]),
+ }
+ | prompt
+ | llm_with_stop
+ | parse_output
+)
+
class AgentInput(BaseModel):
question: str
chat_history: List[Tuple[str, str]]
-agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True).with_types(
- input_type=AgentInput
-)
+
+agent_executor = AgentExecutor(
+ agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
+).with_types(input_type=AgentInput)
agent_executor = agent_executor | (lambda x: x["output"])
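A hypothetical call matching the AgentInput schema above; it assumes Anthropic and You.com credentials are configured, and the question is invented:

# The trailing `| (lambda x: x["output"])` means only the final answer string
# comes back, not the full output dict with intermediate steps.
print(
    agent_executor.invoke(
        {"question": "What were the results of the latest SpaceX launch?", "chat_history": []}
    )
)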
diff --git a/templates/xml-agent/xml_agent/prompts.py b/templates/xml-agent/xml_agent/prompts.py
index a9919010060b3..3dbf96353aa33 100644
--- a/templates/xml-agent/xml_agent/prompts.py
+++ b/templates/xml-agent/xml_agent/prompts.py
@@ -27,14 +27,16 @@
It is 64 degress in SF
-Begin!"""
+Begin!""" # noqa: E501
-conversational_prompt = ChatPromptTemplate.from_messages([
- ("system", template),
- MessagesPlaceholder(variable_name="chat_history"),
- ("user", "{question}"),
- ("ai", "{agent_scratchpad}")
-])
+conversational_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", template),
+ MessagesPlaceholder(variable_name="chat_history"),
+ ("user", "{question}"),
+ ("ai", "{agent_scratchpad}"),
+ ]
+)
def parse_output(message):
@@ -47,4 +49,4 @@ def parse_output(message):
_tool_input = _tool_input.split("</tool_input>")[0]
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
else:
- return AgentFinish(return_values={"output": text}, log=text)
\ No newline at end of file
+ return AgentFinish(return_values={"output": text}, log=text)