diff --git a/libs/aws/langchain_aws/runnables/__init__.py b/libs/aws/langchain_aws/runnables/__init__.py new file mode 100644 index 00000000..dbc787e2 --- /dev/null +++ b/libs/aws/langchain_aws/runnables/__init__.py @@ -0,0 +1,3 @@ +from langchain_aws.runnables.q_business import AmazonQ + +__all__ = ["AmazonQ"] \ No newline at end of file diff --git a/libs/aws/langchain_aws/runnables/q_business.py b/libs/aws/langchain_aws/runnables/q_business.py new file mode 100644 index 00000000..d1ad9e92 --- /dev/null +++ b/libs/aws/langchain_aws/runnables/q_business.py @@ -0,0 +1,153 @@ +import logging +from typing import Any, Dict, Optional + +from langchain_core._api.beta_decorator import beta +from langchain_core.runnables import Runnable +from langchain_core.runnables.config import RunnableConfig +from pydantic import ConfigDict +from typing_extensions import Self + + +@beta(message="This API is in beta and can change in future.") +class AmazonQ(Runnable[str, str]): + """Amazon Q Runnable wrapper. + + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + Make sure the credentials / roles used have the required policies to + access the Amazon Q service. + """ + + region_name: Optional[str] = None + """AWS region name. If not provided, will be extracted from environment.""" + + credentials: Optional[Any] = None + """Amazon Q credentials used to instantiate the client if the client is not provided.""" + + client: Optional[Any] = None + """Amazon Q client.""" + + application_id: str = None + + _last_response: Dict = None # Add this to store the full response + """Store the full response from Amazon Q.""" + + parent_message_id: Optional[str] = None + + conversation_id: Optional[str] = None + + chat_mode: str = "RETRIEVAL_MODE" + + model_config = ConfigDict( + extra="forbid", + ) + + def __init__( + self, + region_name: Optional[str] = None, + credentials: Optional[Any] = None, + client: Optional[Any] = None, + application_id: str = None, + parent_message_id: Optional[str] = None, + conversation_id: Optional[str] = None, + chat_mode: str = "RETRIEVAL_MODE", + ): + self.region_name = region_name + self.credentials = credentials + self.client = client or self.validate_environment() + self.application_id = application_id + self.parent_message_id = parent_message_id + self.conversation_id = conversation_id + self.chat_mode = chat_mode + + def invoke( + self, + input: str, + config: Optional[RunnableConfig] = None, + **kwargs: Any + ) -> str: + """Call out to Amazon Q service. + + Args: + input: The prompt to pass into the model. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + model = AmazonQ( + credentials=your_credentials, + application_id=your_app_id + ) + response = model.invoke("Tell me a joke") + """ + try: + # Prepare the request + request = { + 'applicationId': self.application_id, + 'userMessage': input, + 'chatMode': self.chat_mode, + } + if self.conversation_id: + request.update({ + 'conversationId': self.conversation_id, + 'parentMessageId': self.parent_message_id, + }) + + # Call Amazon Q + response = self.client.chat_sync(**request) + self._last_response = response + + # Extract the response text + if 'systemMessage' in response: + return response["systemMessage"] + else: + raise ValueError("Unexpected response format from Amazon Q") + + except Exception as e: + if "Prompt Length" in str(e): + logging.info(f"Prompt Length: {len(input)}") + print(f"""Prompt: + {input}""") + raise ValueError(f"Error raised by Amazon Q service: {e}") + + def get_last_response(self) -> Dict: + """Method to access the full response from the last call""" + return self._last_response + + def validate_environment(self) -> Self: + """Don't do anything if client provided externally""" + #If the client is not provided, and the user_id is not provided in the class constructor, throw an error saying one or the other needs to be provided + if self.credentials is None: + raise ValueError( + "Either the credentials or the client needs to be provided." + ) + + """Validate that AWS credentials to and python package exists in environment.""" + try: + import boto3 + + try: + if self.region_name is not None: + client = boto3.client('qbusiness', self.region_name, **self.credentials) + else: + # use default region + client = boto3.client('qbusiness', **self.credentials) + + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + except ImportError: + raise ImportError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + return client