Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: logprops for vertex and deepinfra streaming response #896

Merged
merged 6 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/providers/deepseek/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ interface DeepSeekStreamChunk {
object: string;
created: number;
model: string;
usage?: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
};
choices: {
delta: {
role?: string | null;
Expand Down Expand Up @@ -161,6 +166,7 @@ export const DeepSeekChatCompleteResponseTransform: (
export const DeepSeekChatCompleteStreamChunkTransform: (
response: string
) => string = (responseChunk) => {
console.log('responseChunk', responseChunk);
narengogi marked this conversation as resolved.
Show resolved Hide resolved
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
Expand All @@ -182,6 +188,7 @@ export const DeepSeekChatCompleteStreamChunkTransform: (
finish_reason: parsedChunk.choices[0].finish_reason,
},
],
usage: parsedChunk.usage,
})}` + '\n\n'
);
};
25 changes: 24 additions & 1 deletion src/providers/google-vertex-ai/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
import {
ChatCompletionResponse,
ErrorResponse,
Logprobs,
ProviderConfig,
} from '../types';
import {
Expand All @@ -40,7 +41,11 @@ import type {
VertexLLamaChatCompleteResponse,
GoogleSearchRetrievalTool,
} from './types';
import { getMimeType, recursivelyDeleteUnsupportedParameters } from './utils';
import {
getMimeType,
recursivelyDeleteUnsupportedParameters,
transformVertexLogprobs,
} from './utils';

export const buildGoogleSearchRetrievalTool = (tool: Tool) => {
const googleSearchRetrievalTool: GoogleSearchRetrievalTool = {
Expand Down Expand Up @@ -247,6 +252,14 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
logprobs: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
top_logprobs: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes
// Example payload to be included in the request that sets the safety settings:
// "safety_settings": [
Expand Down Expand Up @@ -682,10 +695,20 @@ export const GoogleChatCompleteResponseTransform: (
}),
};
}
const logprobsContent: Logprobs[] | null =
transformVertexLogprobs(generation);
let logprobs;
if (logprobsContent) {
logprobs = {
content: logprobsContent,
};
}

return {
message: message,
index: index,
finish_reason: generation.finishReason,
logprobs,
...(!strictOpenAiCompliance && {
safetyRatings: generation.safetyRatings,
}),
Expand Down
6 changes: 6 additions & 0 deletions src/providers/google-vertex-ai/transformGenerationConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ export function transformGenerationConfig(params: Params) {
if (params?.response_format?.type === 'json_object') {
generationConfig['responseMimeType'] = 'application/json';
}
if (params['logprobs']) {
generationConfig['responseLogprobs'] = params['logprobs'];
}
if (params['top_logprobs']) {
generationConfig['logprobs'] = params['top_logprobs']; // range 1-5, openai supports 1-20
}
if (params?.response_format?.type === 'json_schema') {
generationConfig['responseMimeType'] = 'application/json';
recursivelyDeleteUnsupportedParameters(
Expand Down
82 changes: 51 additions & 31 deletions src/providers/google-vertex-ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,60 @@ export interface GoogleGenerateFunctionCall {
args: Record<string, any>;
}

export interface GoogleGenerateContentResponse {
candidates: {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
export interface GoogleResponseCandidate {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
groundingMetadata?: {
webSearchQueries?: string[];
searchEntryPoint?: {
renderedContent: string;
};
groundingSupports?: Array<{
segment: {
startIndex: number;
endIndex: number;
text: string;
};
groundingChunkIndices: number[];
confidenceScores: number[];
}>;
retrievalMetadata?: {
webDynamicRetrievalScore: number;
};
logprobsResult?: {
topCandidates: [
{
candidates: [
{
token: string;
logProbability: number;
},
];
},
];
chosenCandidates: [
{
token: string;
logProbability: number;
},
];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
}[];
groundingMetadata?: {
webSearchQueries?: string[];
searchEntryPoint?: {
renderedContent: string;
};
groundingSupports?: Array<{
segment: {
startIndex: number;
endIndex: number;
text: string;
};
groundingChunkIndices: number[];
confidenceScores: number[];
}>;
retrievalMetadata?: {
webDynamicRetrievalScore: number;
};
}[];
};
}

export interface GoogleGenerateContentResponse {
candidates: GoogleResponseCandidate[];
promptFeedback: {
safetyRatings: {
category: string;
Expand Down
44 changes: 42 additions & 2 deletions src/providers/google-vertex-ai/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { GoogleErrorResponse } from './types';
import { GoogleErrorResponse, GoogleResponseCandidate } from './types';
import { generateErrorResponse } from '../utils';
import { fileExtensionMimeTypeMap, GOOGLE_VERTEX_AI } from '../../globals';
import { ErrorResponse } from '../types';
import { ErrorResponse, Logprobs } from '../types';

/**
* Encodes an object as a Base64 URL-encoded string.
Expand Down Expand Up @@ -220,3 +220,43 @@ export const recursivelyDeleteUnsupportedParameters = (obj: any) => {
}
}
};

/**
 * Converts a Vertex AI `logprobsResult` into the OpenAI-style
 * `logprobs.content` array (token, logprob, bytes, top_logprobs).
 *
 * @param generation - A single Vertex AI response candidate.
 * @returns The OpenAI-shaped logprobs content, or `null` when the
 *   candidate carries no `logprobsResult`.
 */
export const transformVertexLogprobs = (
  generation: GoogleResponseCandidate
) => {
  if (!generation.logprobsResult) return null;

  // OpenAI's `bytes` field is the UTF-8 byte sequence of the token.
  // (`charCodeAt` would yield UTF-16 code units, which is wrong for
  // any non-ASCII token.)
  const utf8Bytes = (token: string): number[] =>
    Array.from(new TextEncoder().encode(token));

  const logprobsContent: Logprobs[] = [];

  if (generation.logprobsResult.chosenCandidates) {
    for (const candidate of generation.logprobsResult.chosenCandidates) {
      logprobsContent.push({
        token: candidate.token,
        logprob: candidate.logProbability,
        bytes: utf8Bytes(candidate.token),
      });
    }
  }

  if (generation.logprobsResult.topCandidates) {
    generation.logprobsResult.topCandidates.forEach(
      (topCandidatesForIndex, index) => {
        // Guard: Vertex may return topCandidates without (or longer than)
        // chosenCandidates; skip positions with no chosen-token entry
        // instead of throwing on an undefined element.
        if (!logprobsContent[index]) return;
        logprobsContent[index].top_logprobs = topCandidatesForIndex.candidates.map(
          (candidate) => ({
            token: candidate.token,
            logprob: candidate.logProbability,
            bytes: utf8Bytes(candidate.token),
          })
        );
      }
    );
  }

  return logprobsContent;
};
107 changes: 76 additions & 31 deletions src/providers/google/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import {
derefer,
getMimeType,
recursivelyDeleteUnsupportedParameters,
transformVertexLogprobs,
} from '../google-vertex-ai/utils';
import {
ChatCompletionResponse,
ErrorResponse,
Logprobs,
ProviderConfig,
} from '../types';
import {
Expand Down Expand Up @@ -47,6 +49,12 @@ const transformGenerationConfig = (params: Params) => {
if (params?.response_format?.type === 'json_object') {
generationConfig['responseMimeType'] = 'application/json';
}
if (params['logprobs']) {
generationConfig['responseLogprobs'] = params['logprobs'];
}
if (params['top_logprobs']) {
generationConfig['logprobs'] = params['top_logprobs']; // range 1-5, openai supports 1-20
}
if (params?.response_format?.type === 'json_schema') {
generationConfig['responseMimeType'] = 'application/json';
recursivelyDeleteUnsupportedParameters(
Expand Down Expand Up @@ -331,6 +339,14 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
logprobs: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
top_logprobs: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
tools: {
param: 'tools',
default: '',
Expand Down Expand Up @@ -397,40 +413,60 @@ interface GoogleGenerateFunctionCall {
args: Record<string, any>;
}

interface GoogleGenerateContentResponse {
candidates: {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
interface GoogleResponseCandidate {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
groundingMetadata?: {
webSearchQueries?: string[];
searchEntryPoint?: {
renderedContent: string;
};
groundingSupports?: Array<{
segment: {
startIndex: number;
endIndex: number;
text: string;
};
groundingChunkIndices: number[];
confidenceScores: number[];
}>;
retrievalMetadata?: {
webDynamicRetrievalScore: number;
};
logprobsResult?: {
topCandidates: [
{
candidates: [
{
token: string;
logProbability: number;
},
];
},
];
chosenCandidates: [
{
token: string;
logProbability: number;
},
];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
}[];
groundingMetadata?: {
webSearchQueries?: string[];
searchEntryPoint?: {
renderedContent: string;
};
groundingSupports?: Array<{
segment: {
startIndex: number;
endIndex: number;
text: string;
};
groundingChunkIndices: number[];
confidenceScores: number[];
}>;
retrievalMetadata?: {
webDynamicRetrievalScore: number;
};
}[];
};
}

interface GoogleGenerateContentResponse {
candidates: GoogleResponseCandidate[];
promptFeedback: {
safetyRatings: {
category: string;
Expand Down Expand Up @@ -528,8 +564,17 @@ export const GoogleChatCompleteResponseTransform: (
}),
};
}
const logprobsContent: Logprobs[] | null =
transformVertexLogprobs(generation);
let logprobs;
if (logprobsContent) {
logprobs = {
content: logprobsContent,
};
}
return {
message: message,
logprobs,
index: generation.index ?? idx,
finish_reason: generation.finishReason,
...(!strictOpenAiCompliance && generation.groundingMetadata
Expand Down
Loading
Loading