diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py
index e87b048e1366..4cdec20b908d 100644
--- a/autogen/oai/cerebras.py
+++ b/autogen/oai/cerebras.py
@@ -137,45 +137,41 @@ def create(self, params: Dict) -> ChatCompletion:
         streaming_tool_calls = []
 
         ans = None
-        try:
-            response = client.chat.completions.create(**cerebras_params)
-        except Exception as e:
-            raise RuntimeError(f"Cerebras exception occurred: {e}")
-        else:
-
-            if cerebras_params["stream"]:
-                # Read in the chunks as they stream, taking in tool_calls which may be across
-                # multiple chunks if more than one suggested
-                ans = ""
-                for chunk in response:
-                    # Grab first choice, which _should_ always be generated.
-                    ans = ans + (chunk.choices[0].delta.content or "")
-
-                    if chunk.choices[0].delta.tool_calls:
-                        # We have a tool call recommendation
-                        for tool_call in chunk.choices[0].delta.tool_calls:
-                            streaming_tool_calls.append(
-                                ChatCompletionMessageToolCall(
-                                    id=tool_call.id,
-                                    function={
-                                        "name": tool_call.function.name,
-                                        "arguments": tool_call.function.arguments,
-                                    },
-                                    type="function",
-                                )
+        response = client.chat.completions.create(**cerebras_params)
+
+        if cerebras_params["stream"]:
+            # Read in the chunks as they stream, taking in tool_calls which may be across
+            # multiple chunks if more than one suggested
+            ans = ""
+            for chunk in response:
+                # Grab first choice, which _should_ always be generated.
+                ans = ans + (getattr(chunk.choices[0].delta, "content", None) or "")
+
+                if "tool_calls" in chunk.choices[0].delta:
+                    # We have a tool call recommendation
+                    for tool_call in chunk.choices[0].delta["tool_calls"]:
+                        streaming_tool_calls.append(
+                            ChatCompletionMessageToolCall(
+                                id=tool_call["id"],
+                                function={
+                                    "name": tool_call["function"]["name"],
+                                    "arguments": tool_call["function"]["arguments"],
+                                },
+                                type="function",
                             )
+                        )
 
-                    if chunk.choices[0].finish_reason:
-                        prompt_tokens = chunk.x_cerebras.usage.prompt_tokens
-                        completion_tokens = chunk.x_cerebras.usage.completion_tokens
-                        total_tokens = chunk.x_cerebras.usage.total_tokens
-            else:
-                # Non-streaming finished
-                ans: str = response.choices[0].message.content
+                if chunk.choices[0].finish_reason:
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
+                    total_tokens = chunk.usage.total_tokens
+        else:
+            # Non-streaming finished
+            ans: str = response.choices[0].message.content
 
-                prompt_tokens = response.usage.prompt_tokens
-                completion_tokens = response.usage.completion_tokens
-                total_tokens = response.usage.total_tokens
+        prompt_tokens = response.usage.prompt_tokens
+        completion_tokens = response.usage.completion_tokens
+        total_tokens = response.usage.total_tokens
 
         if response is not None:
             if isinstance(response, Stream):
@@ -209,8 +205,6 @@ def create(self, params: Dict) -> ChatCompletion:
                 response_content = response.choices[0].message.content
                 response_id = response.id
 
-        else:
-            raise RuntimeError("Failed to get response from Cerebras after retrying 5 times.")
 
         # 3. convert output
         message = ChatCompletionMessage(
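For reference, the streaming branch above folds deltas from successive chunks into one message: `content` fragments are concatenated, `tool_calls` entries are collected as they arrive (a single suggested call can span several chunks), and token usage is read from the final chunk — the one carrying `finish_reason` — via `chunk.usage` rather than the old `chunk.x_cerebras.usage`. Below is a minimal sketch of that accumulation pattern; the chunk/delta shapes (attribute access for `choices`/`delta`/`usage`, dict-style access for the `tool_calls` payload) are assumptions taken from the diff, not a verbatim Cerebras SDK contract.

    from typing import Any, Dict, List


    def accumulate_stream(chunks: Any) -> Dict[str, Any]:
        """Fold streamed chunks into final content, tool calls, and usage."""
        content = ""
        tool_calls: List[Dict[str, Any]] = []
        usage: Dict[str, int] = {}

        for chunk in chunks:
            delta = chunk.choices[0].delta

            # Content may be absent on tool-call-only chunks, hence getattr.
            content += getattr(delta, "content", None) or ""

            # Tool calls can arrive spread over several chunks; keep appending.
            if "tool_calls" in delta:
                for tc in delta["tool_calls"]:
                    tool_calls.append(
                        {
                            "id": tc["id"],
                            "name": tc["function"]["name"],
                            "arguments": tc["function"]["arguments"],
                        }
                    )

            # Usage only arrives on the final chunk, flagged by finish_reason.
            if chunk.choices[0].finish_reason:
                usage = {
                    "prompt_tokens": chunk.usage.prompt_tokens,
                    "completion_tokens": chunk.usage.completion_tokens,
                    "total_tokens": chunk.usage.total_tokens,
                }

        return {"content": content, "tool_calls": tool_calls, "usage": usage}
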
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 87916319d082..7997a33a39ea 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -49,6 +49,12 @@
     ERROR = None
 
 try:
+    from cerebras.cloud.sdk import (  # noqa
+        AuthenticationError as cerebras_AuthenticationError,
+        InternalServerError as cerebras_InternalServerError,
+        RateLimitError as cerebras_RateLimitError,
+    )
+
     from autogen.oai.cerebras import CerebrasClient
 
     cerebras_import_exception: Optional[ImportError] = None
@@ -868,6 +874,9 @@ def yes_or_no_filter(context, response):
                 ollama_ResponseError,
                 bedrock_BotoCoreError,
                 bedrock_ClientError,
+                cerebras_AuthenticationError,
+                cerebras_InternalServerError,
+                cerebras_RateLimitError,
             ):
                 logger.debug(f"config {i} failed", exc_info=True)
                 if i == last:
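The client.py hunks are what make removing the blanket try/except in cerebras.py safe: instead of every failure being wrapped in an opaque RuntimeError, the Cerebras SDK's AuthenticationError, InternalServerError, and RateLimitError now propagate to the wrapper's per-config exception tuple, so a failing config falls through to the next one in the list. A hedged sketch of that fallback loop follows — the function name and signature are illustrative stand-ins, not the wrapper's real API, and the exception classes are local substitutes for the SDK imports shown in the diff.

    import logging
    from typing import Any, Dict, List

    logger = logging.getLogger(__name__)


    # Stand-ins for the SDK exceptions; client.py imports the real classes
    # from cerebras.cloud.sdk.
    class cerebras_AuthenticationError(Exception): ...
    class cerebras_InternalServerError(Exception): ...
    class cerebras_RateLimitError(Exception): ...

    RETRYABLE_ERRORS = (
        cerebras_AuthenticationError,
        cerebras_InternalServerError,
        cerebras_RateLimitError,
    )


    def create_with_fallback(clients: List[Any], params: Dict[str, Any]) -> Any:
        """Try each configured client in order; re-raise only on the last one."""
        last = len(clients) - 1
        for i, client in enumerate(clients):
            try:
                return client.create(params)
            except RETRYABLE_ERRORS:
                # Mirrors the wrapper's behavior: log and move to the next config.
                logger.debug(f"config {i} failed", exc_info=True)
                if i == last:
                    raise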