diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py
index 0a59a6aa..f35c1566 100644
--- a/docetl/operations/resolve.py
+++ b/docetl/operations/resolve.py
@@ -19,7 +19,7 @@
     rich_as_completed,
     validate_output,
 )
-from docetl.utils import completion_cost
+from docetl.utils import completion_cost, extract_jinja_variables


 def compare_pair(
@@ -433,7 +433,43 @@ def process_cluster(cluster):
                 )
                 return [], reduction_cost
             else:
-                return [input_data[list(cluster)[0]]], 0
+                # Set the output schema to be the keys found in the compare_prompt
+                compare_prompt_keys = extract_jinja_variables(
+                    self.config["comparison_prompt"]
+                )
+                # Get the set of keys in the compare_prompt
+                compare_prompt_keys = set(
+                    [
+                        k.replace("input1.", "")
+                        for k in compare_prompt_keys
+                        if "input1" in k
+                    ]
+                )
+
+                # For each key in the output schema, find the most similar key in the compare_prompt (scored by position-wise character overlap)
+                output_keys = set(self.config["output"]["schema"].keys())
+                key_mapping = {}
+                for output_key in output_keys:
+                    best_match = None
+                    best_score = 0
+                    for compare_key in compare_prompt_keys:
+                        score = sum(
+                            c1 == c2 for c1, c2 in zip(output_key, compare_key)
+                        ) / max(len(output_key), len(compare_key))
+                        if score > best_score:
+                            best_score = score
+                            best_match = compare_key
+                    key_mapping[output_key] = best_match
+
+                # Create the result dictionary using the key mapping
+                result = input_data[list(cluster)[0]].copy()
+                for output_key, compare_key in key_mapping.items():
+                    if compare_key in input_data[list(cluster)[0]]:
+                        result[output_key] = input_data[list(cluster)[0]][compare_key]
+                    else:
+                        result[output_key] = None  # default when the record has no matching key
+
+                return [result], 0

         # Calculate the number of records before and clusters after
         num_records_before = len(input_data)
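The singleton-cluster branch above maps each output-schema key to the most similar key referenced as `input1.*` in the comparison prompt, scoring candidates by position-wise character overlap normalized by the longer key. A minimal standalone sketch of that heuristic, with hypothetical key names for illustration:

```python
def best_key_match(output_key, candidate_keys):
    """Pick the candidate whose characters line up best with output_key."""
    best_match, best_score = None, 0.0
    for candidate in candidate_keys:
        # Position-wise character overlap, normalized by the longer key.
        score = sum(c1 == c2 for c1, c2 in zip(output_key, candidate)) / max(
            len(output_key), len(candidate)
        )
        if score > best_score:
            best_match, best_score = candidate, score
    return best_match


# Hypothetical keys, not taken from any real pipeline config.
print(best_key_match("patient_name", {"patient_names", "dob", "name"}))  # patient_names
```

Keys with no plausible match fall back to `None`, mirroring the `else` branch in the diff.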
""" - if tools is None: - props = {key: convert_val(value) for key, value in output_schema.items()} + props = {key: convert_val(value) for key, value in output_schema.items()} + use_tools = True + + if ( + len(props) == 1 + and list(props.values())[0].get("type") == "string" + and scratchpad is None + and "ollama" in model + ): + use_tools = False + + if tools is None and use_tools: if scratchpad is not None: props["updated_scratchpad"] = {"type": "string"} parameters = {"type": "object", "properties": props} parameters["required"] = list(props.keys()) - parameters["additionalProperties"] = False - - # response_format = { - # "type": "json_schema", - # "json_schema": { - # "name": "write_output", - # "description": "Write task output to a database", - # "strict": True, - # "schema": parameters, - # # "additionalProperties": False, - # }, - # } - - # tools = [] - # tool_choice = "auto" + + # TODO: this is a hack to get around the fact that gemini doesn't support additionalProperties + if "gemini" not in model: + parameters["additionalProperties"] = False tools = [ { @@ -563,15 +562,17 @@ def call_llm_with_cache( } ] tool_choice = {"type": "function", "function": {"name": "send_output"}} - response_format = None - else: + elif tools is not None: tools = json.loads(tools) tool_choice = ( "required" if any(tool.get("required", False) for tool in tools) else "auto" ) tools = [{"type": "function", "function": tool["function"]} for tool in tools] - response_format = None + + else: + tools = None + tool_choice = None system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation. You will perform the specified task on the provided data. The result should be a structured output that you will send back to the user." if scratchpad: @@ -599,7 +600,7 @@ def call_llm_with_cache( # Truncate messages if they exceed the model's context length messages = truncate_messages(messages, model) - if response_format is None: + if tools is not None: response = completion( model=model, messages=[ @@ -622,7 +623,6 @@ def call_llm_with_cache( }, ] + messages, - response_format=response_format, ) return response @@ -634,9 +634,6 @@ def truncate_messages( """ Truncate the messages to fit the model's context length. """ - if "gpt" not in model: - model = "gpt-4o" - model_input_context_length = model_cost.get(model, {}).get("max_input_tokens", 8192) total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages) @@ -882,9 +879,21 @@ def parse_llm_response_helper( Raises: InvalidOutputError: If the response is not valid. 
""" + if not response: raise InvalidOutputError("No response from LLM", [{}], schema, [], []) + tool_calls = ( + response.choices[0].message.tool_calls + if "tool_calls" in dir(response.choices[0].message) + else [] + ) + + # Check if there are no tools and the schema has a single key-value pair + if not tools and len(schema) == 1 and not tool_calls: + key = next(iter(schema)) + return [{key: response.choices[0].message.content}] + # Parse the response based on the provided tools if tools: # If custom tools are provided, parse accordingly @@ -907,54 +916,48 @@ def parse_llm_response_helper( results.append(function_args) return results else: - if "tool_calls" in dir(response.choices[0].message): - # Default behavior for write_output function - tool_calls = response.choices[0].message.tool_calls - - if not tool_calls: - raise InvalidOutputError( - "No tool calls in LLM response", [{}], schema, response.choices, [] - ) + if not tool_calls: + raise InvalidOutputError( + "No tool calls in LLM response", [{}], schema, response.choices, [] + ) - outputs = [] - for tool_call in tool_calls: - try: - output_dict = json.loads(tool_call.function.arguments) - if "ollama" in response.model: - for key, value in output_dict.items(): - if not isinstance(value, str): - continue + outputs = [] + for tool_call in tool_calls: + try: + output_dict = json.loads(tool_call.function.arguments) + if "ollama" in response.model: + for key, value in output_dict.items(): + if not isinstance(value, str): + continue + try: + output_dict[key] = ast.literal_eval(value) + except: try: - output_dict[key] = ast.literal_eval(value) + if value.startswith("["): + output_dict[key] = ast.literal_eval(value + "]") + else: + output_dict[key] = value except: - try: - if value.startswith("["): - output_dict[key] = ast.literal_eval(value + "]") - else: - output_dict[key] = value - except: - pass - outputs.append(output_dict) - except json.JSONDecodeError: - raise InvalidOutputError( - "Could not decode LLM JSON response", - [tool_call.function.arguments], - schema, - response.choices, - tools, - ) - except Exception as e: - raise InvalidOutputError( - f"Error parsing LLM response: {e}", - [tool_call.function.arguments], - schema, - response.choices, - tools, - ) - return outputs + pass + outputs.append(output_dict) + except json.JSONDecodeError: + raise InvalidOutputError( + "Could not decode LLM JSON response", + [tool_call.function.arguments], + schema, + response.choices, + tools, + ) + except Exception as e: + raise InvalidOutputError( + f"Error parsing LLM response: {e}", + [tool_call.function.arguments], + schema, + response.choices, + tools, + ) - else: - return [json.loads(response.choices[0].message.content)] + return outputs # message = response.choices[0].message # return [json.loads(message.content)] diff --git a/docs/best-practices.md b/docs/best-practices.md index 7a75bcf7..1886c3ea 100644 --- a/docs/best-practices.md +++ b/docs/best-practices.md @@ -2,6 +2,25 @@ This guide outlines best practices for using DocETL effectively, focusing on the most important aspects of pipeline creation, execution, and optimization. +!!! 
info "Supported Models" + + DocETL supports many models through LiteLLM: + + - OpenAI models (e.g., GPT-4, GPT-3.5-turbo) + - Anthropic models (e.g., Claude 2, Claude Instant) + - Google VertexAI models (e.g., chat-bison, text-bison) + - Cohere models + - Replicate models + - Azure OpenAI models + - Hugging Face models + - AWS Bedrock models (e.g., Claude, AI21, Cohere) + - Gemini models (e.g., gemini-1.5-pro) + - Ollama models (e.g., llama2) + + For a complete and up-to-date list of supported models, please refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers). You can use the model name just like the litellm documentation (e.g., `openai/gpt-4o-mini` or `gemini/gemini-1.5-flash-002`). + + While DocETL supports various models, it has been primarily tested with OpenAI's language models. Using OpenAI is currently recommended for the best experience and most reliable results, especially for operations that depend on structured outputs. We have also tried gemini-1.5-flash-002 and found it to be pretty good for a much cheaper price. + ## Pipeline Design 1. **Start Simple**: Begin with a basic pipeline and gradually add complexity as needed. diff --git a/docs/tutorial.md b/docs/tutorial.md index 5a16fa7e..eb7d1a44 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -30,6 +30,8 @@ DocETL uses [LiteLLM](https://github.com/BerriAI/litellm) under the hood, which If you choose to use a different provider, be aware that you may encounter unexpected behavior or reduced functionality, especially with operations that depend on structured outputs. We use tool calling to extract structured outputs from the LLM's response, so make sure your provider supports tool calling. + If using a Gemini model, you can use the `gemini` prefix for the model name. For example, `gemini/gemini-1.5-flash-002`. (This has worked pretty well for us so far, and is so cheap!) + If using Ollama (e.g., llama 3.2), make sure your output schemas are not too complex, since these models are not as good as OpenAI for structured outputs! For example, use [parallel map operations](operators/parallel-map.md) to reduce the number of output attributes per prompt. 
diff --git a/poetry.lock b/poetry.lock
index 2f7088cc..63a1eb2b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1534,13 +1534,13 @@ requests = ">=2.20"

 [[package]]
 name = "litellm"
-version = "1.48.7"
+version = "1.48.10"
 description = "Library to easily interface with LLM API providers"
 optional = false
 python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
 files = [
-    {file = "litellm-1.48.7-py3-none-any.whl", hash = "sha256:4971a9e681188635c2ee6dc44fe35bb2774586e9018682adcccdbb516b839c64"},
-    {file = "litellm-1.48.7.tar.gz", hash = "sha256:ff1fef7049e9afa09598f98d1e510a6d5f252ec65c0526b8bfaf13eadfcf65e5"},
+    {file = "litellm-1.48.10-py3-none-any.whl", hash = "sha256:752efd59747a0895f4695d025c66f0b2258d80a61175f7cfa41dbe4894ef95e1"},
+    {file = "litellm-1.48.10.tar.gz", hash = "sha256:0a4ff75da78e66baeae0658ad8de498298310a5efda74c3d840ce2b013e8401d"},
 ]

 [package.dependencies]
diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py
index fc8e68b7..231d6a01 100644
--- a/tests/basic/test_basic_map.py
+++ b/tests/basic/test_basic_map.py
@@ -22,6 +22,7 @@ def test_map_operation(
     map_sample_data,
 ):
     results, cost = test_map_operation_instance.execute(map_sample_data)
+    print(results)

     assert len(results) == len(map_sample_data)
     assert all("sentiment" in result for result in results)