Commit

Merge pull request #59 from ucbepic/shreyashankar/oss
feat: don't use tool calling for ollama/OSS models if the output schema is just one param
shreyashankar authored Oct 3, 2024
2 parents 51d4cf1 + 5a7dc66 commit fcc3368
Showing 6 changed files with 135 additions and 74 deletions.
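In essence, the change gates tool calling on the shape of the output schema: a single string-typed field, with no scratchpad in play, on an Ollama model now skips tool calling entirely. A condensed sketch of the gating logic (a hypothetical standalone helper; the diff in utils.py below inlines this as `use_tools`):

def should_use_tools(props: dict, scratchpad, model: str) -> bool:
    # Skip tool calling only in the narrow case this commit targets:
    # a one-field, string-typed schema, no scratchpad, and an Ollama model.
    single_string_output = (
        len(props) == 1 and list(props.values())[0].get("type") == "string"
    )
    return not (single_string_output and scratchpad is None and "ollama" in model)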
40 changes: 38 additions & 2 deletions docetl/operations/resolve.py
@@ -19,7 +19,7 @@
rich_as_completed,
validate_output,
)
from docetl.utils import completion_cost
from docetl.utils import completion_cost, extract_jinja_variables


def compare_pair(
@@ -433,7 +433,43 @@ def process_cluster(cluster):
)
return [], reduction_cost
else:
return [input_data[list(cluster)[0]]], 0
# Set the output schema to be the keys found in the compare_prompt
compare_prompt_keys = extract_jinja_variables(
self.config["comparison_prompt"]
)
# Get the set of keys in the compare_prompt
compare_prompt_keys = set(
[
k.replace("input1.", "")
for k in compare_prompt_keys
if "input1" in k
]
)

# For each key in the output schema, find the most similar key in the compare_prompt
output_keys = set(self.config["output"]["schema"].keys())
key_mapping = {}
for output_key in output_keys:
best_match = None
best_score = 0
for compare_key in compare_prompt_keys:
score = sum(
c1 == c2 for c1, c2 in zip(output_key, compare_key)
) / max(len(output_key), len(compare_key))
if score > best_score:
best_score = score
best_match = compare_key
key_mapping[output_key] = best_match

# Create the result dictionary using the key mapping
result = input_data[list(cluster)[0]].copy()
for output_key, compare_key in key_mapping.items():
if compare_key in input_data[list(cluster)[0]]:
result[output_key] = input_data[list(cluster)[0]][compare_key]
else:
result[output_key] = None # or some default value

return [result], 0

# Calculate the number of records before and clusters after
num_records_before = len(input_data)
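The fallback branch above maps each output-schema key to the most similar variable referenced in the comparison prompt, scoring pairs by index-aligned character matches divided by the longer key's length. A standalone sketch of that scoring with worked examples (hypothetical key names):

def positional_similarity(a: str, b: str) -> float:
    # Count characters that match at the same position, normalized by
    # the longer string's length.
    return sum(c1 == c2 for c1, c2 in zip(a, b)) / max(len(a), len(b))

print(positional_similarity("summary", "summaries"))  # 6/9 ≈ 0.67
print(positional_similarity("title", "topic"))        # 1/5 = 0.2

Note that this is a positional metric, not an edit distance: a single insertion early in one key shifts every later character and drives the score toward zero.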
141 changes: 72 additions & 69 deletions docetl/operations/utils.py
@@ -527,28 +527,27 @@ def call_llm_with_cache(
Returns:
str: The response from the LLM.
"""
if tools is None:
props = {key: convert_val(value) for key, value in output_schema.items()}
props = {key: convert_val(value) for key, value in output_schema.items()}
use_tools = True

if (
len(props) == 1
and list(props.values())[0].get("type") == "string"
and scratchpad is None
and "ollama" in model
):
use_tools = False

if tools is None and use_tools:
if scratchpad is not None:
props["updated_scratchpad"] = {"type": "string"}

parameters = {"type": "object", "properties": props}
parameters["required"] = list(props.keys())
parameters["additionalProperties"] = False

# response_format = {
# "type": "json_schema",
# "json_schema": {
# "name": "write_output",
# "description": "Write task output to a database",
# "strict": True,
# "schema": parameters,
# # "additionalProperties": False,
# },
# }

# tools = []
# tool_choice = "auto"

# TODO: this is a hack to get around the fact that gemini doesn't support additionalProperties
if "gemini" not in model:
parameters["additionalProperties"] = False

tools = [
{
@@ -563,15 +562,17 @@
}
}
]
tool_choice = {"type": "function", "function": {"name": "send_output"}}
response_format = None

else:
elif tools is not None:
tools = json.loads(tools)
tool_choice = (
"required" if any(tool.get("required", False) for tool in tools) else "auto"
)
tools = [{"type": "function", "function": tool["function"]} for tool in tools]
response_format = None

else:
tools = None
tool_choice = None

system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation. You will perform the specified task on the provided data. The result should be a structured output that you will send back to the user."
if scratchpad:
@@ -599,7 +600,7 @@
# Truncate messages if they exceed the model's context length
messages = truncate_messages(messages, model)

if response_format is None:
if tools is not None:
response = completion(
model=model,
messages=[
@@ -622,7 +623,6 @@
},
]
+ messages,
response_format=response_format,
)

return response
@@ -634,9 +634,6 @@ def truncate_messages(
"""
Truncate the messages to fit the model's context length.
"""
if "gpt" not in model:
model = "gpt-4o"

model_input_context_length = model_cost.get(model, {}).get("max_input_tokens", 8192)
total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages)

@@ -882,9 +879,21 @@ def parse_llm_response_helper(
Raises:
InvalidOutputError: If the response is not valid.
"""

if not response:
raise InvalidOutputError("No response from LLM", [{}], schema, [], [])

tool_calls = (
response.choices[0].message.tool_calls
if "tool_calls" in dir(response.choices[0].message)
else []
)

# Check if there are no tools and the schema has a single key-value pair
if not tools and len(schema) == 1 and not tool_calls:
key = next(iter(schema))
return [{key: response.choices[0].message.content}]

# Parse the response based on the provided tools
if tools:
# If custom tools are provided, parse accordingly
@@ -907,54 +916,48 @@
results.append(function_args)
return results
else:
if "tool_calls" in dir(response.choices[0].message):
# Default behavior for write_output function
tool_calls = response.choices[0].message.tool_calls

if not tool_calls:
raise InvalidOutputError(
"No tool calls in LLM response", [{}], schema, response.choices, []
)
if not tool_calls:
raise InvalidOutputError(
"No tool calls in LLM response", [{}], schema, response.choices, []
)

outputs = []
for tool_call in tool_calls:
try:
output_dict = json.loads(tool_call.function.arguments)
if "ollama" in response.model:
for key, value in output_dict.items():
if not isinstance(value, str):
continue
outputs = []
for tool_call in tool_calls:
try:
output_dict = json.loads(tool_call.function.arguments)
if "ollama" in response.model:
for key, value in output_dict.items():
if not isinstance(value, str):
continue
try:
output_dict[key] = ast.literal_eval(value)
except:
try:
output_dict[key] = ast.literal_eval(value)
if value.startswith("["):
output_dict[key] = ast.literal_eval(value + "]")
else:
output_dict[key] = value
except:
try:
if value.startswith("["):
output_dict[key] = ast.literal_eval(value + "]")
else:
output_dict[key] = value
except:
pass
outputs.append(output_dict)
except json.JSONDecodeError:
raise InvalidOutputError(
"Could not decode LLM JSON response",
[tool_call.function.arguments],
schema,
response.choices,
tools,
)
except Exception as e:
raise InvalidOutputError(
f"Error parsing LLM response: {e}",
[tool_call.function.arguments],
schema,
response.choices,
tools,
)
return outputs
pass
outputs.append(output_dict)
except json.JSONDecodeError:
raise InvalidOutputError(
"Could not decode LLM JSON response",
[tool_call.function.arguments],
schema,
response.choices,
tools,
)
except Exception as e:
raise InvalidOutputError(
f"Error parsing LLM response: {e}",
[tool_call.function.arguments],
schema,
response.choices,
tools,
)

else:
return [json.loads(response.choices[0].message.content)]
return outputs

# message = response.choices[0].message
# return [json.loads(message.content)]
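On the parsing side, the new early return means that when no tools were used and the schema has exactly one key, the raw message content is returned under that key without any JSON decoding. A simplified sketch of just that path (not the full helper above):

def parse_single_key_response(response, schema: dict) -> list[dict]:
    # With exactly one schema field and no tool calls, the whole message
    # body is taken as that field's value.
    key = next(iter(schema))
    return [{key: response.choices[0].message.content}]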
19 changes: 19 additions & 0 deletions docs/best-practices.md
@@ -2,6 +2,25 @@

This guide outlines best practices for using DocETL effectively, focusing on the most important aspects of pipeline creation, execution, and optimization.

!!! info "Supported Models"

DocETL supports many models through LiteLLM:

- OpenAI models (e.g., GPT-4, GPT-3.5-turbo)
- Anthropic models (e.g., Claude 2, Claude Instant)
- Google VertexAI models (e.g., chat-bison, text-bison)
- Cohere models
- Replicate models
- Azure OpenAI models
- Hugging Face models
- AWS Bedrock models (e.g., Claude, AI21, Cohere)
- Gemini models (e.g., gemini-1.5-pro)
- Ollama models (e.g., llama2)

For a complete and up-to-date list of supported models, please refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers). Use model names exactly as they appear in the LiteLLM documentation (e.g., `openai/gpt-4o-mini` or `gemini/gemini-1.5-flash-002`).

While DocETL supports various models, it has been primarily tested with OpenAI's language models. OpenAI is currently recommended for the best experience and most reliable results, especially for operations that depend on structured outputs. We have also tried gemini-1.5-flash-002 and found it to perform well at a much lower price.
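Model names are passed through to LiteLLM, so the provider prefix selects the backend. A minimal sketch of the underlying LiteLLM call (illustrative prompt; not DocETL-specific code):

from litellm import completion

# The prefix before "/" routes the request to the matching provider.
response = completion(
    model="gemini/gemini-1.5-flash-002",
    messages=[{"role": "user", "content": "Summarize this document: ..."}],
)
print(response.choices[0].message.content)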

## Pipeline Design

1. **Start Simple**: Begin with a basic pipeline and gradually add complexity as needed.
2 changes: 2 additions & 0 deletions docs/tutorial.md
@@ -30,6 +30,8 @@ DocETL uses [LiteLLM](https://github.com/BerriAI/litellm) under the hood, which

If you choose to use a different provider, be aware that you may encounter unexpected behavior or reduced functionality, especially with operations that depend on structured outputs. We use tool calling to extract structured outputs from the LLM's response, so make sure your provider supports tool calling.

If using a Gemini model, use the `gemini` prefix for the model name, e.g., `gemini/gemini-1.5-flash-002`. (This has worked well for us so far, and is very cheap!)

If using Ollama (e.g., llama 3.2), keep your output schemas simple, since these models are not as capable as OpenAI's at producing structured outputs! For example, use [parallel map operations](operators/parallel-map.md) to reduce the number of output attributes per prompt (see the sketch below).
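For instance, a single string output, which per this commit bypasses tool calling for Ollama models, is much more reliable than a multi-attribute schema. Illustrative schemas (hypothetical field names, shown in the internal dict form):

# Safe for Ollama: one string field, so no tool calling is needed.
simple_schema = {"summary": "string"}

# Riskier for Ollama: multiple fields go through tool calling; consider
# splitting them across parallel map operations instead.
complex_schema = {"summary": "string", "sentiment": "string", "theme": "string"}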

## Preparing the Data
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions tests/basic/test_basic_map.py
@@ -22,6 +22,7 @@ def test_map_operation(
map_sample_data,
):
results, cost = test_map_operation_instance.execute(map_sample_data)
print(results)

assert len(results) == len(map_sample_data)
assert all("sentiment" in result for result in results)
