-
Notifications
You must be signed in to change notification settings - Fork 365
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: hongyishi <shihongyi88@gmail.com>
- Loading branch information
Showing
2 changed files
with
77 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from nemoguardrails import LLMRails, RailsConfig | ||
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, LLMPredictor | ||
from llama_index.indices.query.base import BaseQueryEngine | ||
from langchain.llms.base import BaseLLM | ||
|
||
from typing import Callable | ||
|
||
# Colang rails definition for the demo:
#  - a guardrail flow that refuses to respond to ill-intent messages, and
#  - a question-answering flow that delegates to the `llama_index_query`
#    action registered in `demo()` below and echoes its answer.
# NOTE(review): leading indentation inside the string was lost in the diff
# scrape; it is restored here because Colang utterances/flow steps must be
# indented under their `define` headers — confirm against the original file.
COLANG_CONFIG = """
define user express greeting
  "hi"

define user express ill intent
  "I hate you"
  "I want to destroy the world"

define bot express cannot respond
  "I'm sorry I cannot help you with that."

define user express question
  "What is the current unemployment rate?"

# Basic guardrail example
define flow
  user express ill intent
  bot express cannot respond

# Question answering flow
define flow
  user express question
  $answer = execute llama_index_query(query=$last_user_message)
  bot $answer
"""

# Model configuration consumed by RailsConfig.from_content: a single "main"
# OpenAI completion model.
YAML_CONFIG = """
models:
  - type: main
    engine: openai
    model: text-davinci-003
"""
|
||
|
||
def _get_llama_index_query_engine(llm: BaseLLM):
    """Build a llama_index query engine over the grounding-rail knowledge base.

    Loads the report markdown document, indexes it into a vector store using
    *llm* as the predictor, and returns the index's default query engine.
    """
    documents = SimpleDirectoryReader(
        input_files=["../examples/grounding_rail/kb/report.md"]
    ).load_data()
    predictor = LLMPredictor(llm=llm)
    vector_index = GPTVectorStoreIndex.from_documents(
        documents, llm_predictor=predictor
    )
    return vector_index.as_query_engine()
|
||
|
||
def _get_callable_query_engine(
    query_engine: BaseQueryEngine
) -> Callable[[str], str]:
    """Wrap *query_engine* in an async callable suitable as a rails action.

    The returned coroutine function takes the user's query string and returns
    the plain-text response extracted from the engine's result object.
    """

    async def get_query_response(query: str) -> str:
        # NOTE(review): `query_engine.query` is a synchronous call, so it
        # blocks the event loop for the duration of the lookup.
        result = query_engine.query(query)
        return result.response

    return get_query_response
|
||
|
||
def demo():
    """Wire a llama_index query engine into NeMo Guardrails and run one turn.

    Builds the rails app from the inline Colang/YAML config, registers the
    knowledge-base query engine as the `llama_index_query` action, then sends
    a sample question through the rails and prints the generated reply.
    """
    rails_config = RailsConfig.from_content(COLANG_CONFIG, YAML_CONFIG)
    rails_app = LLMRails(rails_config)

    engine: BaseQueryEngine = _get_llama_index_query_engine(rails_app.llm)
    rails_app.register_action(
        _get_callable_query_engine(engine), name="llama_index_query"
    )

    conversation = [
        {"role": "user", "content": "What is the current unemployment rate?"}
    ]
    reply = rails_app.generate(messages=conversation)
    print(reply)
|
||
|
||
# Run the demo only when this file is executed as a script (not on import).
if __name__ == "__main__":
    demo()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ uvicorn==0.21.1 | |
httpx==0.23.3 | ||
simpleeval==0.9.13 | ||
typing-extensions==4.5.0 | ||
llama_index==0.6.14 |