diff --git a/tests/unit/vertexai/test_prompt.py b/tests/unit/vertexai/test_prompt.py new file mode 100644 index 00000000000..e59c6ca4385 --- /dev/null +++ b/tests/unit/vertexai/test_prompt.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unit tests for generative model prompts.""" +# pylint: disable=protected-access,bad-continuation + +from vertexai.generative_models._prompt import Prompt +from vertexai.generative_models import Content, Part + +import pytest + +from typing import Any + + +def _is_list_type(obj: Any, T: Any) -> bool: + return isinstance(obj, list) and all(isinstance(s, T) for s in obj) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPrompt: + """Unit tests for generative model prompts.""" + + def test_inferred_prompt_constructor(self): + # Create inferred prompt with string only variable values + prompt = Prompt( + prompt_data="Rate the movie {movie1}", + variables={ + "movie1": "The Avengers", + }, + ) + # Inferred prompt data should remain as string before compilation + assert prompt.prompt_data == "Rate the movie {movie1}" + # Variables values should be converted to List[Part] + assert _is_list_type(prompt.variables["movie1"], Part) + + # Create inferred prompt with List[Part] variable values + prompt = Prompt( + prompt_data="Rate the movie {movie1}", + variables={ + "movie1": [Part.from_text("The Avengers")], + }, + ) + # Variables values should be converted to List[Part] + assert _is_list_type(prompt.variables["movie1"], Part) + + # Inferred prompt variables must either be string or List[Part] + with pytest.raises(ValueError): + Prompt( + prompt_data="Rate the movie {movie1}", + variables={ + "movie1": Part.from_text("The Avengers"), + }, + ) + + def test_inferred_prompt_to_content(self): + prompt = Prompt( + prompt_data="Which movie is better, {movie1} or {movie2}?", + variables={ + "movie1": "The Avengers", + "movie2": "Frozen", + }, + ) + unassembled_prompt_content = prompt.to_content() + expected_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, {movie1} or {movie2}?"), + ], + role="user", + ) + ] + assert unassembled_prompt_content[0].role == expected_content[0].role + for i in range(len(unassembled_prompt_content[0].parts)): + assert ( + unassembled_prompt_content[0].parts[i].text + == expected_content[0].parts[i].text + ) + + # Check assembled prompt content + prompt.assemble() + assembled_prompt_content = prompt.to_content() + expected_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, "), + Part.from_text("The Avengers"), + Part.from_text(" or "), + Part.from_text("Frozen"), + Part.from_text("?"), + ], + role="user", + ) + ] + assert assembled_prompt_content[0].role == expected_content[0].role + for i in range(len(assembled_prompt_content[0].parts)): + assert ( + assembled_prompt_content[0].parts[i].text + == expected_content[0].parts[i].text + ) + + def test_inferred_prompt_assemble(self): + prompt = Prompt( + prompt_data="Which movie is better, {movie1} or {movie2}?", + variables={ + "movie1": "The Avengers", + }, + ) + + # Check partially assembled prompt content + prompt.assemble() + assembled1_prompt_content = prompt.to_content() + expected1_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, "), + Part.from_text("The Avengers"), + Part.from_text(" or "), + Part.from_text("{movie2}"), + Part.from_text("?"), + ], + role="user", + ) + ] + assert assembled1_prompt_content[0].role == expected1_content[0].role + for i in range(len(assembled1_prompt_content[0].parts)): + assert ( + assembled1_prompt_content[0].parts[i].text + == expected1_content[0].parts[i].text + ) + + # Check assembled prompt content + prompt.assemble(movie2="Frozen") + assembled2_prompt_content = prompt.to_content() + expected2_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, "), + Part.from_text("The Avengers"), + Part.from_text(" or "), + Part.from_text("Frozen"), + Part.from_text("?"), + ], + role="user", + ) + ] + assert assembled2_prompt_content[0].role == expected2_content[0].role + for i in range(len(assembled2_prompt_content[0].parts)): + assert ( + assembled2_prompt_content[0].parts[i].text + == expected2_content[0].parts[i].text + ) + + # Check assembled prompt content with override + prompt.assemble(movie1="Inception") + assembled3_prompt_content = prompt.to_content() + expected3_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, "), + Part.from_text("Inception"), + Part.from_text(" or "), + Part.from_text("Frozen"), + Part.from_text("?"), + ], + role="user", + ) + ] + assert assembled3_prompt_content[0].role == expected3_content[0].role + for i in range(len(assembled3_prompt_content[0].parts)): + assert ( + assembled3_prompt_content[0].parts[i].text + == expected3_content[0].parts[i].text + ) diff --git a/vertexai/generative_models/_prompt.py b/vertexai/generative_models/_prompt.py new file mode 100644 index 00000000000..cb6af4ef895 --- /dev/null +++ b/vertexai/generative_models/_prompt.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vertexai.generative_models import Content, Part + +import re +from typing import ( + Any, + Dict, + List, + Union, +) + +VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})" + + +class ExpandablePartsVariable: + """A List[Parts] variable that can be expanded to multiple parts. + + The name field should be the name of the variable, not wrapped with {}. This + class wraps a text Part with the name of the variable wrapped with {}. + + Usage: + ``` + prompt = Prompt( + prompt_data=[Content( + parts=[ + Part.from_text("What do these pictures have in common?"), + ExpandablePartsVariable(name="pictures"), + ], + role="user" + )], + variables={"pictures": [ + Part.from_uri(uri="gs://.../img1.jpg", mime_type="image/jpeg"), + Part.from_uri(uri="gs://.../img2.jpg", mime_type="image/jpeg"), + ]} + ) + ``` + """ + + def __init__(self, name: str): + """Initializes the ExpandablePartsVariable with a name. + + Args: + name: The name of the variable, not wrapped with {}. + """ + self.name = name + self._text_part = Part.from_text(f"{{{name}}}") + self._raw_part = self._text_part._raw_part + + def to_part(self) -> Part: + """Returns the underlying text Part of the ExpandablePartsVariable. + The text is the name of the variable wrapped with {}. + """ + return self._text_part + + +class Prompt: + """A prompt which may be a template with variables. + + The `Prompt` class allows users to define a template string with + variables represented in curly braces `{variable}`. The variable + name must be a valid Python variable name (no spaces, must start with a + letter). These placeholders can be replaced with specific values using the + `assemble` method, providing flexibility in generating dynamic prompts. + + Usage: + + ``` + prompt_data = + prompt = Prompt( + prompt_data="Hello, {name}! Today is {day}. How are you?", + variables={"name": "Alice", "day": "Monday"} + ) + prompt.assemble() + prompt.to_content() + ``` + """ + + def __init__( + self, + prompt_data: Union[str, List[Content]], + variables: Dict[str, Union[str, List[Part]]], + ): + """Initializes the Prompt with a given prompt, and variables. + + Args: + prompt: A string or List[Content] prompt which may be a template with variables or a prompt with no variables. + variables: A dictionary containing the variable names and values. + """ + self._check_valid_prompt_data(prompt_data) + self.prompt_data = prompt_data + self._check_valid_variables(variables) + self.variables = self._format_variables(variables) + self._assembled_prompt = None + + def set_prompt(self, prompt_data: Union[str, List[Content]]) -> None: + """Overwrites the existing prompt_data. + + Args: + prompt_data: A string or List[Content] prompt. + """ + self._check_valid_prompt_data(prompt_data) + self.prompt_data = prompt_data + self._assembled_prompt = None + + def set_variables(self, variables: Dict[str, Union[str, List[Part]]]) -> None: + """Overwrites the existing variables dictionary. + + Args: + variables: A dictionary containing the variable names and values. + """ + self._check_valid_variables(variables) + self.variables = self._format_variables(variables) + self._assembled_prompt = None + + def _format_variables( + self, variables: Dict[str, Union[str, List[Part]]] + ) -> Dict[str, List[Part]]: + """Formats the variables values to be List[Part].""" + for key in variables.keys(): + if isinstance(variables[key], str): + variables[key] = [Part.from_text(variables[key])] + return variables + + def _is_list_type(self, obj: Any, T: Any) -> bool: + return isinstance(obj, list) and all(isinstance(s, T) for s in obj) + + def _check_valid_prompt_data(self, prompt_data: Any) -> None: + if not ( + isinstance(prompt_data, str) or self._is_list_type(prompt_data, Content) + ): + raise ValueError( + "Prompt data must be a string or a list of Content objects." + ) + + def _check_valid_variables(self, variables: Any) -> None: + """Dict must be a Dict[str, Union[str, List[Part]].""" + valid = True + if isinstance(variables, dict): + for key, value in variables.items(): + if not isinstance(key, str): + valid = False + if not (isinstance(value, str) or self._is_list_type(value, Part)): + valid = False + else: + valid = False + if not valid: + raise ValueError( + "Variable values must be a string or a list of Part objects." + ) + + def assemble(self, **kwargs) -> None: + """Replaces only the provided variables in the template with specific values. + + Args: + **kwargs: Keyword arguments where keys are placeholder names and values + are the replacements. + + Returns: + A new PromptTemplate instance with the updated template string. + + Usage: + ``` + prompt = Prompt( + prompt_data="Hello, {name}! Today is {day}. How are you?", + variables={"name": "Alice"} + ) + prompt.assemble() + prompt.to_content() + # Returns "Hello, Alice! Today is {day}. How are you?" as List[Content] + prompt.assemble(day="Monday") + prompt.to_content() + # Returns "Hello, Alice! Today is Monday. How are you?" as List[Content] + prompt.assemble(name="Bob") + prompt.to_content() + # Returns "Hello, Bob! Today is Monday. How are you?" as List[Content] + ``` + """ + # Python Dict update will overwrite existing key values. + self.variables.update(kwargs) + + # Convert the variables values to List[Part]. + formatted_variables = self._format_variables(self.variables) + self._check_valid_variables(formatted_variables) + + # Convert inferred prompt type (str) to static type (List[Content]) + if isinstance(self.prompt_data, str): + # Step 1) Find and isolate variables as their own string. + prompt_data_str_split = re.split(VARIABLE_NAME_REGEX, self.prompt_data) + + # Step 2) Replace each list element as a Part or ExpandablePartsVariable. + for i in range(len(prompt_data_str_split)): + if re.match(VARIABLE_NAME_REGEX, prompt_data_str_split[i]): + prompt_data_str_split[i] = ExpandablePartsVariable( + name=prompt_data_str_split[i][1:-1] + ) + else: + prompt_data_str_split[i] = Part.from_text(prompt_data_str_split[i]) + + # Step 3) Wrap List[Part] as a single Content object. + list_content_prompt_data = [ + Content( + parts=prompt_data_str_split, + role="user", + ) + ] + elif self._is_list_type(self.prompt_data, Content): + list_content_prompt_data = self.prompt_data + else: + raise ValueError( + "Prompt data must be a string or a list of Content objects." + ) + + # Assemble the List[Content] prompt + for i in range(len(list_content_prompt_data)): + content = list_content_prompt_data[i] + expanded_parts = [] + for part in content.parts: + # Infer part of format {variable_name} to be a ExpandablePartsVariable. + if ( + re.match(VARIABLE_NAME_REGEX, part.text) + and part.text[1:-1] in formatted_variables + ): + expanded_parts.extend(formatted_variables[part.text[1:-1]]) + else: + expanded_parts.append(part) + + list_content_prompt_data[i] = Content( + parts=expanded_parts, role=content.role + ) + self._assembled_prompt = list_content_prompt_data + + def to_content(self) -> List[Content]: + """Returns the prompt data, assembled if prompt.assemble() was called. + Can be ingested into model.generate_content to make API calls. + + Returns: + A List[Content] prompt. + Usage: + ``` + prompt = Prompt( + prompt_data="Hello, {name}! Today is {day}. How are you?", + variables={"name": "Alice", "day": "Monday"} + ) + prompt.assemble(day="Monday") + model.generate_content( + contents=prompt.to_content() + ) + ``` + """ + if self._assembled_prompt: + # self._assembled_prompt must be List[Content] + return self._assembled_prompt + elif self._is_list_type(self.prompt_data, Content): + # If prompt_data is a List[Content], return it + return self.prompt_data + elif isinstance(self.prompt_data, str): + # If prompt_data is a string, wrap it with a Content object with singleton Part. + return [ + Content( + parts=[Part.from_text(self.prompt_data)], + role="user", + ) + ] + else: + raise ValueError( + "Prompt data must be a string or a list of Content objects." + ) + + def get_unassembled_prompt(self) -> Union[List[Content], str]: + """Returns the prompt data, without any variables replaced.""" + return self.prompt_data + + def __str__(self) -> str: + """Returns the prompt data, assembled if prompt.assemble() was called.""" + return str(self._assembled_prompt) or self.prompt_data + + def __repr__(self) -> str: + """Returns a string representation of the unassembled prompt.""" + return f"Prompt(prompt_data='{self.prompt_data}', variables={self.variables})"