Skip to content

Commit

Permalink
refactor(groq): return final Response (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisB-TL authored Feb 23, 2025
1 parent 6162343 commit 086a373
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 120 deletions.
7 changes: 4 additions & 3 deletions src/Providers/Groq/Groq.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse;
use EchoLabs\Prism\Providers\Groq\Handlers\Text;
use EchoLabs\Prism\Structured\Request as StructuredRequest;
use EchoLabs\Prism\Structured\Response as StructuredResponse;
use EchoLabs\Prism\Text\Request;
use EchoLabs\Prism\ValueObjects\ProviderResponse;
use EchoLabs\Prism\Text\Response as TextResponse;
use Illuminate\Http\Client\PendingRequest;
use Illuminate\Support\Facades\Http;

Expand All @@ -22,15 +23,15 @@ public function __construct(
) {}

#[\Override]
public function text(Request $request): ProviderResponse
public function text(Request $request): TextResponse
{
$handler = new Text($this->client($request->clientOptions(), $request->clientRetry()));

return $handler->handle($request);
}

#[\Override]
public function structured(StructuredRequest $request): ProviderResponse
public function structured(StructuredRequest $request): StructuredResponse
{
throw new \Exception(sprintf('%s does not support structured mode', class_basename($this)));
}
Expand Down
143 changes: 115 additions & 28 deletions src/Providers/Groq/Handlers/Text.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,91 @@

namespace EchoLabs\Prism\Providers\Groq\Handlers;

use EchoLabs\Prism\Concerns\CallsTools;
use EchoLabs\Prism\Enums\FinishReason;
use EchoLabs\Prism\Exceptions\PrismException;
use EchoLabs\Prism\Providers\Groq\Maps\FinishReasonMap;
use EchoLabs\Prism\Providers\Groq\Maps\MessageMap;
use EchoLabs\Prism\Providers\Groq\Maps\ToolChoiceMap;
use EchoLabs\Prism\Providers\Groq\Maps\ToolMap;
use EchoLabs\Prism\Text\Request;
use EchoLabs\Prism\ValueObjects\ProviderResponse;
use EchoLabs\Prism\Text\Response as TextResponse;
use EchoLabs\Prism\Text\ResponseBuilder;
use EchoLabs\Prism\Text\Step;
use EchoLabs\Prism\ValueObjects\Messages\AssistantMessage;
use EchoLabs\Prism\ValueObjects\Messages\ToolResultMessage;
use EchoLabs\Prism\ValueObjects\ResponseMeta;
use EchoLabs\Prism\ValueObjects\ToolCall;
use EchoLabs\Prism\ValueObjects\ToolResult;
use EchoLabs\Prism\ValueObjects\Usage;
use Illuminate\Http\Client\PendingRequest;
use Illuminate\Http\Client\Response;
use Throwable;

class Text
{
public function __construct(protected PendingRequest $client) {}
use CallsTools;

public function handle(Request $request): ProviderResponse
protected ResponseBuilder $responseBuilder;

public function __construct(protected PendingRequest $client)
{
$this->responseBuilder = new ResponseBuilder;
}

public function handle(Request $request): TextResponse
{
$data = $this->sendRequest($request);

$this->validateResponse($data);

$responseMessage = new AssistantMessage(
data_get($data, 'message.content') ?? '',
$this->mapToolCalls(data_get($data, 'choices.0.message.tool_calls', []) ?? []),
);

$this->responseBuilder->addResponseMessage($responseMessage);

$request->addMessage($responseMessage);

$finishReason = FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', ''));

return match ($finishReason) {
FinishReason::ToolCalls => $this->handleToolCalls($data, $request),
FinishReason::Stop, FinishReason::Length => $this->handleStop($data, $request, $finishReason),
default => throw new PrismException('Groq: unhandled finish reason'),
};
}

/**
* @return array<string, mixed>
*/
protected function sendRequest(Request $request): array
{
try {
$response = $this->sendRequest($request);
$response = $this->client->post(
'chat/completions',
array_filter([
'model' => $request->model(),
'messages' => (new MessageMap($request->messages(), $request->systemPrompts()))(),
'max_tokens' => $request->maxTokens(),
'temperature' => $request->temperature(),
'top_p' => $request->topP(),
'tools' => ToolMap::map($request->tools()),
'tool_choice' => ToolChoiceMap::map($request->toolChoice()),
])
);

return $response->json();
} catch (Throwable $e) {
throw PrismException::providerRequestError($request->model(), $e);
}
}

$data = $response->json();

/**
* @param array<string, mixed> $data
*/
protected function validateResponse(array $data): void
{
if (! $data || data_get($data, 'error')) {
throw PrismException::providerResponseError(vsprintf(
'Groq Error: [%s] %s',
Expand All @@ -41,37 +98,67 @@ public function handle(Request $request): ProviderResponse
]
));
}
}

/**
* @param array<string, mixed> $data
*/
protected function handleToolCalls(array $data, Request $request): TextResponse
{
$toolResults = $this->callTools(
$request->tools(),
$this->mapToolCalls(data_get($data, 'choices.0.message.tool_calls', []) ?? []),
);

$request->addMessage(new ToolResultMessage($toolResults));

$this->addStep($data, $request, FinishReason::ToolCalls, $toolResults);

if ($this->shouldContinue($request)) {
return $this->handle($request);
}

return $this->responseBuilder->toResponse();
}

/**
* @param array<string, mixed> $data
*/
protected function handleStop(array $data, Request $request, FinishReason $finishReason): TextResponse
{
$this->addStep($data, $request, $finishReason);

return $this->responseBuilder->toResponse();
}

protected function shouldContinue(Request $request): bool
{
return $this->responseBuilder->steps->count() < $request->maxSteps();
}

return new ProviderResponse(
/**
* @param array<string, mixed> $data
* @param ToolResult[] $toolResults
*/
protected function addStep(array $data, Request $request, FinishReason $finishReason, array $toolResults = []): void
{
$this->responseBuilder->addStep(new Step(
text: data_get($data, 'choices.0.message.content') ?? '',
finishReason: $finishReason,
toolCalls: $this->mapToolCalls(data_get($data, 'choices.0.message.tool_calls', []) ?? []),
toolResults: $toolResults,
usage: new Usage(
data_get($data, 'usage.prompt_tokens'),
data_get($data, 'usage.completion_tokens'),
),
finishReason: FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', '')),
responseMeta: new ResponseMeta(
id: data_get($data, 'id'),
model: data_get($data, 'model'),
)
);
}

public function sendRequest(Request $request): Response
{
return $this->client->post(
'chat/completions',
array_merge([
'model' => $request->model(),
'messages' => (new MessageMap($request->messages(), $request->systemPrompts()))(),
'max_tokens' => $request->maxTokens ?? 2048,
], array_filter([
'temperature' => $request->temperature(),
'top_p' => $request->topP(),
'tools' => ToolMap::map($request->tools()),
'tool_choice' => ToolChoiceMap::map($request->toolChoice()),
]))
);
),
messages: $request->messages(),
systemPrompts: $request->systemPrompts(),
additionalContent: [],
));
}

/**
Expand Down
48 changes: 1 addition & 47 deletions tests/Fixtures/groq/generate-text-with-multiple-tools-1.json
Original file line number Diff line number Diff line change
@@ -1,47 +1 @@
{
"id": "chatcmpl-2f85156f-4864-4621-a977-6767e74251b5",
"object": "chat.completion",
"created": 1730035768,
"model": "llama3-groq-70b-8192-tool-use-preview",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"tool_calls": [
{
"id": "call_4b20",
"type": "function",
"function": {
"name": "search",
"arguments": "{\"query\": \"tigers game today in Detroit\"}"
}
},
{
"id": "call_hvk7",
"type": "function",
"function": {
"name": "weather",
"arguments": "{\"city\": \"Detroit\"}"
}
}
]
},
"logprobs": null,
"finish_reason": "tool_calls"
}
],
"usage": {
"queue_time": 0.012132366000000002,
"prompt_tokens": 296,
"prompt_time": 0.02126043,
"completion_tokens": 55,
"completion_time": 0.173039707,
"total_tokens": 351,
"total_time": 0.194300137
},
"system_fingerprint": "fp_ee4b521143",
"x_groq": {
"id": "req_01jb72npr0fj0t9p423mrdt52h"
}
}
{"id":"chatcmpl-1cbf7b0f-c99a-4807-a820-e8d91c585015","object":"chat.completion","created":1740311145,"model":"llama-3.3-70b-versatile","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_3whd","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Detroit\"}"}},{"id":"call_6xxk","type":"function","function":{"name":"search","arguments":"{\"query\": \"Tigers game time today in Detroit\"}"}}]},"logprobs":null,"finish_reason":"tool_calls"}],"usage":{"queue_time":0.233586791,"prompt_tokens":310,"prompt_time":0.016113644,"completion_tokens":32,"completion_time":0.116363636,"total_tokens":342,"total_time":0.13247728},"system_fingerprint":"fp_7b42aeb9fa","x_groq":{"id":"req_01jmsa1b8tfmb9h9k7gh9tej8x"}}
31 changes: 1 addition & 30 deletions tests/Fixtures/groq/generate-text-with-multiple-tools-2.json
Original file line number Diff line number Diff line change
@@ -1,30 +1 @@
{
"id": "chatcmpl-e4daf477-4536-4f23-9c3e-de490185423f",
"object": "chat.completion",
"created": 1730035716,
"model": "llama3-groq-70b-8192-tool-use-preview",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "The Tigers game is at 3pm in Detroit. Given the weather is 75° and sunny, it's likely to be warm, so you might not need a coat. However, it's always a good idea to check the weather closer to the game time as it can change."
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"queue_time": 0.027946203000000003,
"prompt_tokens": 48,
"prompt_time": 0.006288822,
"completion_tokens": 59,
"completion_time": 0.187873195,
"total_tokens": 107,
"total_time": 0.194162017
},
"system_fingerprint": "fp_ee4b521143",
"x_groq": {
"id": "req_01jb72m451f7yr7cmtb0bb7yxs"
}
}
{"id":"chatcmpl-6c7d4a2f-5950-4430-bb08-4e12d7e94e04","object":"chat.completion","created":1740311147,"model":"llama-3.3-70b-versatile","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_hmv6","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Detroit\"}"}}]},"logprobs":null,"finish_reason":"tool_calls"}],"usage":{"queue_time":0.268621691,"prompt_tokens":377,"prompt_time":0.019224374,"completion_tokens":26,"completion_time":0.094545455,"total_tokens":403,"total_time":0.113769829},"system_fingerprint":"fp_0a4b7a8df3","x_groq":{"id":"req_01jmsa1c7cf6pb1efvnwzbkk8n"}}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"id":"chatcmpl-8288c3f5-e381-4ca1-8472-f926970b8392","object":"chat.completion","created":1740311147,"model":"llama-3.3-70b-versatile","choices":[{"index":0,"message":{"role":"assistant","content":"Based on the weather, you won't need a coat for the Tigers game today in Detroit. It's going to be 75° and sunny. The game starts at 3 pm."},"logprobs":null,"finish_reason":"stop"}],"usage":{"queue_time":0.281927537,"prompt_tokens":409,"prompt_time":0.021915109,"completion_tokens":39,"completion_time":0.141818182,"total_tokens":448,"total_time":0.163733291},"system_fingerprint":"fp_7b42aeb9fa","x_groq":{"id":"req_01jmsa1ct3frb8vhbq3ppbaky8"}}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"id":"chatcmpl-3c1a57b2-da9f-454d-82cd-4ab1b0cc2456","object":"chat.completion","created":1730556571,"model":"llama3-groq-70b-8192-tool-use-preview","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_2rps","type":"function","function":{"name":"weather","arguments":"{\"city\": \"New York\"}"}},{"id":"call_gzsm","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Los Angeles\"}"}},{"id":"call_b0by","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Chicago\"}"}},{"id":"call_h6gq","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Houston\"}"}},{"id":"call_zyrx","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Philadelphia\"}"}},{"id":"call_d427","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Phoenix\"}"}},{"id":"call_kd8c","type":"function","function":{"name":"weather","arguments":"{\"city\": \"San Antonio\"}"}},{"id":"call_w6w4","type":"function","function":{"name":"weather","arguments":"{\"city\": \"San Diego\"}"}},{"id":"call_7qsn","type":"function","function":{"name":"weather","arguments":"{\"city\": \"Dallas\"}"}},{"id":"call_p7mz","type":"function","function":{"name":"weather","arguments":"{\"city\": \"San Jose\"}"}}]},"logprobs":null,"finish_reason":"tool_calls"}],"usage":{"queue_time":0.012363021000000002,"prompt_tokens":299,"prompt_time":0.021131387,"completion_tokens":187,"completion_time":0.5974724,"total_tokens":486,"total_time":0.618603787},"system_fingerprint":"fp_ee4b521143","x_groq":{"id":"req_01jbpkbc6kfe2ag6btbxqypb4z"}}
{"id":"chatcmpl-753b3c7d-6f5b-4a0c-bf32-66d05e44ba1a","object":"chat.completion","created":1740311332,"model":"llama-3.3-70b-versatile","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_qc7n","type":"function","function":{"name":"weather","arguments":"{\"city\": \"New York\"}"}}]},"logprobs":null,"finish_reason":"tool_calls"}],"usage":{"queue_time":0.234223463,"prompt_tokens":312,"prompt_time":0.015721485,"completion_tokens":11,"completion_time":0.04,"total_tokens":323,"total_time":0.055721485},"system_fingerprint":"fp_5f849c5a0b","x_groq":{"id":"req_01jmsa71nke0zrnk8n4acyh6kh"}}
23 changes: 12 additions & 11 deletions tests/Providers/Groq/GroqTextTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
];

$response = Prism::text()
->using('groq', 'llama3-groq-70b-8192-tool-use-preview')
->using('groq', 'llama-3.3-70b-versatile')
->withTools($tools)
->withMaxSteps(3)
->withPrompt('What time is the tigers game today in Detroit and should I wear a coat?')
Expand All @@ -79,27 +79,28 @@
// Assert tool calls in the first step
$firstStep = $response->steps[0];
expect($firstStep->toolCalls)->toHaveCount(2);
expect($firstStep->toolCalls[0]->name)->toBe('search');

expect($firstStep->toolCalls[0]->name)->toBe('weather');
expect($firstStep->toolCalls[0]->arguments())->toBe([
'query' => 'tigers game today in Detroit',
'city' => 'Detroit',
]);

expect($firstStep->toolCalls[1]->name)->toBe('weather');
expect($firstStep->toolCalls[1]->name)->toBe('search');
expect($firstStep->toolCalls[1]->arguments())->toBe([
'city' => 'Detroit',
'query' => 'Tigers game time today in Detroit',
]);

// Assert usage
expect($response->usage->promptTokens)->toBe(344);
expect($response->usage->completionTokens)->toBe(114);
expect($response->usage->promptTokens)->toBe(1096);
expect($response->usage->completionTokens)->toBe(97);

// Assert response
expect($response->responseMeta->id)->toBe('chatcmpl-e4daf477-4536-4f23-9c3e-de490185423f');
expect($response->responseMeta->model)->toBe('llama3-groq-70b-8192-tool-use-preview');
expect($response->responseMeta->id)->toBe('chatcmpl-8288c3f5-e381-4ca1-8472-f926970b8392');
expect($response->responseMeta->model)->toBe('llama-3.3-70b-versatile');

// Assert final text content
expect($response->text)->toBe(
"The Tigers game is at 3pm in Detroit. Given the weather is 75° and sunny, it's likely to be warm, so you might not need a coat. However, it's always a good idea to check the weather closer to the game time as it can change."
"Based on the weather, you won't need a coat for the Tigers game today in Detroit. It's going to be 75° and sunny. The game starts at 3 pm."
);
});

Expand All @@ -118,7 +119,7 @@
];

$response = Prism::text()
->using(Provider::Groq, 'llama3-groq-70b-8192-tool-use-preview')
->using(Provider::Groq, 'llama-3.3-70b-versatile')
->withPrompt('Do something')
->withTools($tools)
->withToolChoice('weather')
Expand Down

0 comments on commit 086a373

Please sign in to comment.