Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hongyishi committed Jun 12, 2023
1 parent cebc4a6 commit 13bdf7c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 25 deletions.
64 changes: 41 additions & 23 deletions examples/demo_llama_index_guardrails.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from nemoguardrails import LLMRails, RailsConfig
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, LLMPredictor
from llama_index.indices.query.base import BaseQueryEngine
from langchain.llms.base import BaseLLM

from typing import Callable
from typing import Callable, Any, Coroutine

COLANG_CONFIG = """
define user express greeting
Expand All @@ -26,7 +24,7 @@
# Question answering flow
define flow
user express question
user ...
$answer = execute llama_index_query(query=$last_user_message)
bot $answer
Expand All @@ -40,29 +38,49 @@
"""


def _get_llama_index_query_engine(llm: BaseLLM):
    """Build the default llama_index query engine over the demo knowledge base.

    Args:
        llm: The langchain LLM that llama_index should drive for synthesis.

    Returns:
        A query engine backed by a vector index of the grounding-rail report.
    """
    # Load the single knowledge-base document that backs the Q&A flow.
    loaded_docs = SimpleDirectoryReader(
        input_files=["../examples/grounding_rail/kb/report.md"]
    ).load_data()
    # Wrap the rails LLM so llama_index uses it instead of its default model.
    predictor = LLMPredictor(llm=llm)
    vector_index = GPTVectorStoreIndex.from_documents(
        loaded_docs, llm_predictor=predictor
    )
    return vector_index.as_query_engine()


def _get_callable_query_engine(
    query_engine: BaseQueryEngine,
) -> Callable[[str], Coroutine[Any, Any, str]]:
    """Wrap *query_engine* in an async callable usable as a rails action.

    Args:
        query_engine: The llama_index query engine to delegate to.

    Returns:
        An async function mapping a user query string to the engine's
        textual answer, with a missing answer normalized to "".
    """

    async def get_query_response(query: str) -> str:
        # NOTE(review): query_engine.query() is a blocking call inside an
        # async action; presumably acceptable for a demo — confirm.
        response_text = query_engine.query(query).response
        # .response may be None when synthesis produced no text; never
        # hand None to the bot message flow.
        return "" if response_text is None else response_text

    return get_query_response
def demo():
try:
import llama_index
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.response.schema import StreamingResponse

except ImportError:
raise ImportError(
"Could not import llama_index, please install it with "
"`pip install llama_index`."
)

def demo():
config = RailsConfig.from_content(COLANG_CONFIG, YAML_CONFIG)
app = LLMRails(config)
query_engine: BaseQueryEngine = _get_llama_index_query_engine(app.llm)

def _get_llama_index_query_engine(llm: BaseLLM):
    """Create a llama_index query engine over the grounding-rail report.

    Args:
        llm: The langchain LLM to plug into llama_index for answer synthesis.

    Returns:
        The index's default query engine.
    """
    # Ingest the knowledge-base file the Q&A flow answers from.
    kb_docs = llama_index.SimpleDirectoryReader(
        input_files=["../examples/grounding_rail/kb/report.md"]
    ).load_data()
    # Route llama_index's LLM calls through the rails-configured model.
    predictor = llama_index.LLMPredictor(llm=llm)
    kb_index = llama_index.GPTVectorStoreIndex.from_documents(
        kb_docs, llm_predictor=predictor
    )
    return kb_index.as_query_engine()

def _get_callable_query_engine(
    query_engine: BaseQueryEngine,
) -> Callable[[str], Coroutine[Any, Any, str]]:
    """Adapt *query_engine* into an async rails action returning plain text.

    Args:
        query_engine: The llama_index query engine to delegate to.

    Returns:
        An async function mapping a query string to the answer text
        ("" when the engine produced no answer).
    """

    async def get_query_response(query: str) -> str:
        result = query_engine.query(query)
        # A streaming result must be materialized before .response is valid.
        if isinstance(result, StreamingResponse):
            result = result.get_response()
        answer = result.response
        # Normalize a missing answer to the empty string for the bot flow.
        return "" if answer is None else answer

    return get_query_response

query_engine = _get_llama_index_query_engine(app.llm)
app.register_action(
_get_callable_query_engine(query_engine), name="llama_index_query"
)
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ starlette==0.26.1
uvicorn==0.21.1
httpx==0.23.3
simpleeval==0.9.13
typing-extensions==4.5.0
llama_index==0.6.14
typing-extensions==4.5.0

0 comments on commit 13bdf7c

Please sign in to comment.