Commit

wip
Benjoyo committed Aug 15, 2023
1 parent 09875ee commit 4cf0245
Showing 9 changed files with 51 additions and 30 deletions.
7 changes: 5 additions & 2 deletions python/src/gpt/agents/retrieval_agent/retrieval_agent.py
@@ -6,6 +6,7 @@
 from pydantic import BaseModel, Field
 from pydantic.fields import FieldInfo
 
+from gpt.agents.common.agent.base import Agent
 from gpt.agents.common.agent.openai_functions.openai_functions_agent import OpenAIFunctionsAgent
 from gpt.chains.retrieval_chain.chain import get_embeddings, get_vector_store
 from gpt.chains.retrieval_chain.prompt import MULTI_QUERY_PROMPT
@@ -42,24 +43,26 @@ def json_schema_to_pydantic_model(name: str, schema: Dict[str, Any]) -> Any:
 
 def create_retrieval_agent(
     llm: BaseChatModel,
+    database: str,
     database_url: str,
     embedding_provider: str,
     embedding_model: str,
+    output_schema: Optional[Dict[str, Union[str, dict]]] = None,
-) -> Chain:
+) -> Agent:
     agent = OpenAIFunctionsAgent.create(
         llm=llm,
     )
 
     embeddings = get_embeddings(embedding_provider, embedding_model)
-    vector_store = get_vector_store(database_url, embeddings)
+    vector_store = get_vector_store(database, database_url, embeddings)
 
     # rephrase query multiple times and get union of docs
     # multi_retriever = MultiQueryRetriever.from_llm(
     #     retriever=vector_store.as_retriever(),
     #     llm=llm,
     #     prompt=MULTI_QUERY_PROMPT
     # )
 
     # answer synthesizer
     retrieval_qa = RetrievalQA.from_chain_type(
         llm=llm,
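For reference, a minimal sketch of how the reworked factory might be called after this change, with the vector database kind now passed separately from its URL and an optional output schema. The ChatOpenAI model, the example URL, and the schema contents are illustrative assumptions rather than part of this commit; the run() call mirrors the one added to the /retrieval endpoint further down.

from langchain.chat_models import ChatOpenAI

from gpt.agents.retrieval_agent.retrieval_agent import create_retrieval_agent

# Sketch only: values are placeholders and the exact output_schema format is assumed.
agent = create_retrieval_agent(
    llm=ChatOpenAI(model_name="gpt-4", temperature=0),
    database="weaviate",                               # database kind, new in this commit
    database_url="http://localhost:8080/Test_index",   # plain URL, no 'weaviate://' prefix any more
    embedding_provider="openai",
    embedding_model="text-embedding-ada-002",
    output_schema={"answer": "string"},                # optional; omit for a plain answer
)
result = agent.run(input="What is the notice period?", context="")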
12 changes: 6 additions & 6 deletions python/src/gpt/chains/retrieval_chain/chain.py
@@ -14,11 +14,10 @@
 from gpt.chains.support.sub_query_retriever.chain import SubQueryRetriever
 
 
-def get_vector_store(database_url: str, embeddings: Embeddings, meta_attributes: Optional[List[str]] = None):
-    db, url = database_url.split('://', 1)
-    match db:
+def get_vector_store(database: str, database_url: str, embeddings: Embeddings, meta_attributes: Optional[List[str]] = None):
+    match database:
         case 'weaviate':
-            base_url, index = url.rsplit('/', 1)
+            base_url, index = database_url.rsplit('/', 1)
             return Weaviate(
                 client=_create_weaviate_client(weaviate_url=base_url),
                 index_name=index,
@@ -28,7 +27,7 @@ def get_vector_store(database_url: str, embeddings: Embeddings, meta_attributes:
                 by_text=False
             )
         case _:
-            raise Exception(f'Unsupported vector database {db} in url {database_url}')
+            raise Exception(f'Unsupported vector database {database}.')
 
 
 def get_embeddings(embedding_provider: str, embedding_model: str) -> Embeddings:
@@ -41,14 +40,15 @@ def get_embeddings(embedding_provider: str, embedding_model: str) -> Embeddings:
 
 def create_legacy_retrieval_chain(
     llm: BaseLanguageModel,
+    database: str,
     database_url: str,
     embedding_provider: str,
     embedding_model: str,
     mode: str = 'standard'
 ) -> Chain:
 
     embeddings = get_embeddings(embedding_provider, embedding_model)
-    vector_store = get_vector_store(database_url, embeddings)
+    vector_store = get_vector_store(database, database_url, embeddings)
 
     # rephrase query multiple times and get union of docs
     multi_retriever = MultiQueryRetriever.from_llm(
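With the scheme prefix gone, the only URL parsing left in get_vector_store is splitting the Weaviate class name off the end of the URL with rsplit. A quick sketch of that split and of a call with the new signature; the embeddings instance and the similarity_search query are illustrative, and only the 'weaviate' kind and the SkillLibrary URL appear in this commit's tests.

from langchain.embeddings import OpenAIEmbeddings

from gpt.chains.retrieval_chain.chain import get_vector_store

# rsplit('/', 1) keeps everything before the last '/' as the Weaviate base URL:
base_url, index = 'http://localhost:8080/SkillLibrary'.rsplit('/', 1)
assert (base_url, index) == ('http://localhost:8080', 'SkillLibrary')

store = get_vector_store(
    'weaviate',                             # database kind; anything else raises
    'http://localhost:8080/SkillLibrary',   # base URL plus index/class name
    OpenAIEmbeddings(),
    meta_attributes=['task', 'comment', 'function', 'example_call'],
)
docs = store.similarity_search("How do I look up an account?", k=4)  # standard LangChain vector store call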
36 changes: 22 additions & 14 deletions python/src/gpt/server/server.py
@@ -9,11 +9,12 @@
 from gpt.agents.plan_and_execute.executor.executor import create_executor
 from gpt.agents.plan_and_execute.planner.planner import create_planner
 from gpt.agents.process_generation_agent.process_generation_agent import create_process_generation_agent
+from gpt.agents.retrieval_agent.retrieval_agent import create_retrieval_agent
 from gpt.chains.compose_chain.chain import create_compose_chain
 from gpt.chains.decide_chain.chain import create_decide_chain
 from gpt.chains.extract_chain.chain import create_extract_chain
 from gpt.chains.generic_chain.chain import create_generic_chain
-from gpt.chains.retrieval_chain.chain import create_retrieval_chain, get_vector_store
+from gpt.chains.retrieval_chain.chain import get_vector_store
 from gpt.chains.translate_chain.chain import create_translate_chain
 from gpt.config import model_id_to_llm
 from gpt.server.types import RetrievalTask, ComposeTask, GenericTask, TranslateTask, DecideTask, ExtractTask, \
@@ -91,11 +92,12 @@ async def post(task: GenericTask):
 @app.post("/openapi")
 async def post(task: OpenApiTask):
     if task.skill_store_url:
-        skill_store = get_vector_store(
-            task.skill_store_url,
-            OpenAIEmbeddings(),
-            meta_attributes=['task', 'comment', 'function', 'example_call']
-        )
+        # skill_store = get_vector_store(
+        #     task.skill_store_url,
+        #     OpenAIEmbeddings(),
+        #     meta_attributes=['task', 'comment', 'function', 'example_call']
+        # )
+        skill_store = None
     else:
         skill_store = None
 
@@ -113,11 +115,12 @@ async def post(task: OpenApiTask):
 @app.post("/database")
 async def post(task: DatabaseTask):
     if task.skill_store_url:
-        skill_store = get_vector_store(
-            task.skill_store_url,
-            OpenAIEmbeddings(),
-            meta_attributes=['task', 'comment', 'function', 'example_call']
-        )
+        # skill_store = get_vector_store(
+        #     task.skill_store_url,
+        #     OpenAIEmbeddings(),
+        #     meta_attributes=['task', 'comment', 'function', 'example_call']
+        # )
+        skill_store = None
     else:
         skill_store = None
 
@@ -134,14 +137,19 @@ async def post(task: DatabaseTask):
 
 @app.post("/retrieval")
 async def post(task: RetrievalTask):
-    chain = create_retrieval_chain(
+    agent = create_retrieval_agent(
         llm=model_id_to_llm(task.model),
+        database=task.database,
         database_url=task.database_url,
         embedding_provider=task.embedding_provider,
         embedding_model=task.embedding_model,
-        mode=task.mode
+        output_schema=task.output_schema
     )
-    return chain.run(query=task.query)
+    result = agent.run(input=task.query, context="")
+    if task.output_schema:
+        return result["output"]
+    else:
+        return result["output"]["answer"]
 
 
 @app.post("/process")
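The /retrieval endpoint now expects the extra fields and branches on output_schema when shaping the response. A hedged example request against a locally running server follows; the host, port, and all values are placeholders, and mode is kept only because the RetrievalTask model below still requires it.

import requests

payload = {
    "model": "gpt-4",
    "database": "weaviate",
    "database_url": "http://localhost:8080/Test_index",
    "embedding_provider": "openai",
    "embedding_model": "text-embedding-ada-002",
    "mode": "standard",                     # no longer used by the endpoint, but still a required field
    "query": "What is the notice period?",
    "output_schema": None,                  # None -> plain answer; a dict -> structured output
}
resp = requests.post("http://localhost:8000/retrieval", json=payload)  # port is an assumption
print(resp.json())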
2 changes: 2 additions & 0 deletions python/src/gpt/server/types.py
@@ -66,11 +66,13 @@ class DatabaseTask(BaseModel):
 
 class RetrievalTask(BaseModel):
     model: str
+    database: str
     database_url: str
     embedding_provider: str
     embedding_model: str
     mode: str
     query: str
+    output_schema: Optional[dict] = None
 
 
 class ProcessTask(BaseModel):
12 changes: 8 additions & 4 deletions python/tests/manual_integration/test.py
@@ -287,7 +287,8 @@ def test_index_test_docs():
 
 def test_create_skill_index():
     vs = get_vector_store(
-        'weaviate://http://localhost:8080/SkillLibrary',
+        'weaviate',
+        'http://localhost:8080/SkillLibrary',
         OpenAIEmbeddings(),
     )
     vs._client.schema.create_class({
@@ -318,7 +319,8 @@ def test_create_skill_index():
 
 def test_clear_skills():
     vs = get_vector_store(
-        'weaviate://http://localhost:8080/SkillLibrary',
+        'weaviate',
+        'http://localhost:8080/SkillLibrary',
         OpenAIEmbeddings(),
         meta_attributes=['task', 'comment', 'function', 'example_call']
     )
@@ -328,7 +330,8 @@ def test_clear_skills():
 def test_retrieve():
     qa = create_legacy_retrieval_chain(
         llm=get_openai_chat_llm(),
-        database_url='weaviate://http://localhost:8080/Test_index',
+        database='weaviate',
+        database_url='http://localhost:8080/Test_index',
         embedding_provider="openai",
         embedding_model="text-embedding-ada-002"
     )
@@ -341,7 +344,8 @@ def test_flare_instruct():
     llm = get_openai_chat_llm(model_name='gpt-4')
 
     retriever = get_vector_store(
-        'weaviate://http://localhost:8080/Test_index',
+        'weaviate',
+        'http://localhost:8080/Test_index',
         OpenAIEmbeddings()
     ).as_retriever()
 
3 changes: 2 additions & 1 deletion python/tests/test_ui/agent_database.py
@@ -26,7 +26,8 @@ def memory():
     return AgentMemory()
 
 skill_store = get_vector_store(
-    'weaviate://http://localhost:8080/SkillLibrary',
+    'weaviate',
+    'http://localhost:8080/SkillLibrary',
     OpenAIEmbeddings(),
     meta_attributes=['task', 'comment', 'function', 'example_call']
 )
3 changes: 2 additions & 1 deletion python/tests/test_ui/code_agent.py
@@ -57,7 +57,8 @@ def get_accounts():
     ]
 
 skill_store = get_vector_store(
-    'weaviate://http://localhost:8080/SkillLibrary',
+    'weaviate',
+    'http://localhost:8080/SkillLibrary',
     OpenAIEmbeddings(),
     meta_attributes=['task', 'comment', 'function', 'example_call']
 )
3 changes: 2 additions & 1 deletion python/tests/test_ui/llama.py
@@ -68,7 +68,8 @@ def llama():
 llm = llama()
 
 vs = get_vector_store(
-    'weaviate://http://localhost:8080/Test_index',
+    'weaviate',
+    'http://localhost:8080/Test_index',
     OpenAIEmbeddings(),
 )
 
3 changes: 2 additions & 1 deletion python/tests/test_ui/skill_library.py
@@ -12,7 +12,8 @@
 st.title('📚 Skill Library')
 
 vector_store = get_vector_store(
-    'weaviate://http://localhost:8080/SkillLibrary',
+    'weaviate',
+    'http://localhost:8080/SkillLibrary',
     OpenAIEmbeddings(),
     meta_attributes=['task', 'comment', 'function', 'example_call']
 )
