Skip to content

Commit

Permalink
Merge pull request #681 from b4s36t4/feat/vertex-controlled-output
Browse files Browse the repository at this point in the history
feat: allow pydantic support for vertex controlled output generation
  • Loading branch information
VisargD authored Oct 22, 2024
2 parents d09632f + 0ca0376 commit 1cda698
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/providers/google-vertex-ai/transformGenerationConfig.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Params } from '../../types/requestBody';

import { derefer } from './utils';
/**
* @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#request_body
*/
Expand Down Expand Up @@ -28,7 +28,14 @@ export function transformGenerationConfig(params: Params) {
}
if (params?.response_format?.type === 'json_schema') {
generationConfig['responseMimeType'] = 'application/json';
generationConfig['responseSchema'] = params?.response_format.json_schema;
let schema =
params?.response_format?.json_schema?.schema ??
params?.response_format?.json_schema;
if (Object.keys(schema).includes('$defs')) {
schema = derefer(schema);
delete schema['$defs'];
}
generationConfig['responseSchema'] = schema;
}

return generationConfig;
Expand Down
36 changes: 36 additions & 0 deletions src/providers/google-vertex-ai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,39 @@ export const GoogleErrorResponseTransform: (

return undefined;
};

const getDefFromRef = (ref: string) => {
const refParts = ref.split('/');
return refParts.at(-1);
};

const getRefParts = (spec: Record<string, any>, ref: string) => {
return spec?.[ref];
};

export const derefer = (spec: Record<string, any>, defs = null) => {
const original = { ...spec };

const finalDefs = defs ?? original?.['$defs'];
const entries = Object.entries(original);

for (let [key, object] of entries) {
if (key === '$defs') {
continue;
}
if (typeof object === 'string' || Array.isArray(object)) {
continue;
}
const ref = object?.['$ref'];
if (ref) {
const def = getDefFromRef(ref);
const defData = getRefParts(finalDefs, def ?? '');
const newValue = derefer(defData, finalDefs);
original[key] = newValue;
} else {
const newValue = derefer(object, finalDefs);
original[key] = newValue;
}
}
return original;
};

0 comments on commit 1cda698

Please sign in to comment.