From 848015d3f493d38695d1ffc4c517d83ab2009122 Mon Sep 17 00:00:00 2001
From: hongyishi
Date: Mon, 5 Jun 2023 15:33:05 -0700
Subject: [PATCH] Hongyi LlamaIndex NeMo demo

---
 examples/demo_llama_index_guardrails.py | 97 +++++++++++++++++++++++++
 requirements.txt                        |  2 +-
 2 files changed, 98 insertions(+), 1 deletion(-)
 create mode 100644 examples/demo_llama_index_guardrails.py

diff --git a/examples/demo_llama_index_guardrails.py b/examples/demo_llama_index_guardrails.py
new file mode 100644
index 000000000..1601e8309
--- /dev/null
+++ b/examples/demo_llama_index_guardrails.py
@@ -0,0 +1,97 @@
+from typing import Any, Callable, Coroutine
+
+from langchain.llms.base import BaseLLM
+from nemoguardrails import LLMRails, RailsConfig
+
+COLANG_CONFIG = """
+define user express greeting
+  "hi"
+
+define user express ill intent
+  "I hate you"
+  "I want to destroy the world"
+
+define bot express cannot respond
+  "I'm sorry, I cannot help you with that."
+
+define user express question
+  "What is the current unemployment rate?"
+
+# Basic guardrail example
+define flow
+  user express ill intent
+  bot express cannot respond
+
+# Question answering flow
+define flow
+  user ...
+  $answer = execute llama_index_query(query=$last_user_message)
+  bot $answer
+
+"""
+
+YAML_CONFIG = """
+models:
+  - type: main
+    engine: openai
+    model: text-davinci-003
+"""
+
+
+def demo():
+    try:
+        import llama_index
+        from llama_index.indices.query.base import BaseQueryEngine
+        from llama_index.response.schema import StreamingResponse
+    except ImportError:
+        raise ImportError(
+            "Could not import llama_index, please install it with "
+            "`pip install llama-index`."
+        )
+
+    # Load the guardrails configuration from the inline Colang and YAML above.
+    config = RailsConfig.from_content(COLANG_CONFIG, YAML_CONFIG)
+    app = LLMRails(config)
+
+    def _get_llama_index_query_engine(llm: BaseLLM):
+        # Index the sample knowledge base with the same LLM that powers the rails.
+        docs = llama_index.SimpleDirectoryReader(
+            input_files=["../examples/grounding_rail/kb/report.md"]
+        ).load_data()
+        llm_predictor = llama_index.LLMPredictor(llm=llm)
+        index = llama_index.GPTVectorStoreIndex.from_documents(
+            docs, llm_predictor=llm_predictor
+        )
+        default_query_engine = index.as_query_engine()
+        return default_query_engine
+
+    def _get_callable_query_engine(
+        query_engine: BaseQueryEngine,
+    ) -> Callable[[str], Coroutine[Any, Any, str]]:
+        # Wrap the query engine in an async callable that can run as a rails action.
+        async def get_query_response(query: str) -> str:
+            response = query_engine.query(query)
+            if isinstance(response, StreamingResponse):
+                typed_response = response.get_response()
+            else:
+                typed_response = response
+            response_str = typed_response.response
+            if response_str is None:
+                return ""
+            return response_str
+
+        return get_query_response
+
+    query_engine = _get_llama_index_query_engine(app.llm)
+    # Register the engine as the "llama_index_query" action used in the Colang flow.
+    app.register_action(
+        _get_callable_query_engine(query_engine), name="llama_index_query"
+    )
+
+    history = [{"role": "user", "content": "What is the current unemployment rate?"}]
+    result = app.generate(messages=history)
+    print(result)
+
+
+if __name__ == "__main__":
+    demo()
diff --git a/requirements.txt b/requirements.txt
index b5980e788..b241d3f98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,4 @@ starlette==0.26.1
 uvicorn==0.21.1
 httpx==0.23.3
 simpleeval==0.9.13
-typing-extensions==4.5.0
+typing-extensions==4.5.0
\ No newline at end of file
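
For reference, a minimal sketch of also exercising the ill-intent rail, assuming
the lines below are appended at the end of demo() after the action is registered
(the exact bot wording depends on the model):

    # Hypothetical follow-up: the Colang flow above should route this message
    # to "bot express cannot respond" instead of the LlamaIndex query engine.
    history = [{"role": "user", "content": "I want to destroy the world"}]
    result = app.generate(messages=history)
    print(result)
    # Roughly: {"role": "assistant", "content": "I'm sorry, I cannot help you with that."}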