diff --git a/plugins/default/default.test.ts b/plugins/default/default.test.ts index 56ecf27c7..80263e035 100644 --- a/plugins/default/default.test.ts +++ b/plugins/default/default.test.ts @@ -10,6 +10,7 @@ import { handler as logHandler } from './log'; import { handler as allUppercaseHandler } from './alluppercase'; import { handler as endsWithHandler } from './endsWith'; import { handler as allLowerCaseHandler } from './alllowercase'; +import { handler as modelWhitelistHandler } from './modelWhitelist'; import { z } from 'zod'; import { PluginContext, PluginParameters } from '../types'; @@ -802,3 +803,36 @@ describe('allLowercase handler', () => { expect(result.verdict).toBe(false); }); }); + +describe('modelWhitelist handler', () => { + it('should return true verdict when the model requested is part of the whitelist', async () => { + const context: PluginContext = { + request: { json: { model: 'gemini-1.5-flash-001' } }, + }; + + const parameters: PluginParameters = { + models: ['gemini-1.5-flash-001'], + }; + const eventType = 'beforeRequestHook'; + + const result = await modelWhitelistHandler(context, parameters, eventType); + + expect(result.error).toBe(null); + expect(result.verdict).toBe(true); + }); + it('should return false verdict when the model requested is not part of the whitelist', async () => { + const context: PluginContext = { + request: { json: { model: 'gemini-1.5-pro-001' } }, + }; + + const parameters: PluginParameters = { + models: ['gemini-1.5-flash-001'], + }; + const eventType = 'beforeRequestHook'; + + const result = await modelWhitelistHandler(context, parameters, eventType); + + expect(result.error).toBe(null); + expect(result.verdict).toBe(false); + }); +}); diff --git a/plugins/default/manifest.json b/plugins/default/manifest.json index 3f62d45d8..177c96fae 100644 --- a/plugins/default/manifest.json +++ b/plugins/default/manifest.json @@ -475,6 +475,37 @@ } ], "parameters": {} + }, + { + "name": "Model whitelisting", + "id": "modelwhitelist", + "type": "guardrail", + "supportedHooks": ["beforeRequestHook"], + "description": [ + { + "type": "subHeading", + "text": "Check if the model in the request is part of the allowed model list." + } + ], + "parameters": { + "type": "object", + "properties": { + "models": { + "type": "array", + "label": "Model list", + "description": [ + { + "type": "subHeading", + "text": "Enter the allowed models." + } + ], + "items": { + "type": "string" + } + } + }, + "required": ["models"] + } } ] } diff --git a/plugins/default/modelWhitelist.ts b/plugins/default/modelWhitelist.ts new file mode 100644 index 000000000..cfd240ed6 --- /dev/null +++ b/plugins/default/modelWhitelist.ts @@ -0,0 +1,25 @@ +import { + HookEventType, + PluginContext, + PluginHandler, + PluginParameters, +} from '../types'; + +export const handler: PluginHandler = async ( + context: PluginContext, + parameters: PluginParameters, + eventType: HookEventType +) => { + let error = null; + let verdict = false; + + try { + const modelList = parameters.models; + let requestModel = context.request?.json.model; + verdict = modelList.includes(requestModel); + } catch (e) { + error = e as Error; + } + + return { error, verdict }; +}; diff --git a/plugins/index.ts b/plugins/index.ts index 1137e7778..03e7c6785 100644 --- a/plugins/index.ts +++ b/plugins/index.ts @@ -12,6 +12,7 @@ import { handler as defaultcontainsCode } from './default/containsCode'; import { handler as defaultalluppercase } from './default/alluppercase'; import { handler as defaultalllowercase } from './default/alllowercase'; import { handler as defaultendsWith } from './default/endsWith'; +import { handler as defaultmodelWhitelist } from './default/modelWhitelist'; import { handler as portkeymoderateContent } from './portkey/moderateContent'; import { handler as portkeylanguage } from './portkey/language'; import { handler as portkeypii } from './portkey/pii'; @@ -48,6 +49,7 @@ export const plugins = { alluppercase: defaultalluppercase, alllowercase: defaultalllowercase, endsWith: defaultendsWith, + modelWhitelist: defaultmodelWhitelist, }, portkey: { moderateContent: portkeymoderateContent,