Skip to content

Commit

Permalink
Add support for users to specify custom request settings, model and optionally provider specific (#14535)
Browse files Browse the repository at this point in the history

* Fix request settings and stop words in HF provider

fixed #14503

Signed-off-by: Jonas Helming <jhelming@eclipsesource.com>
Co-authored-by: Stefan Dirix <sdirix@eclipsesource.com>
  • Loading branch information
JonasHelming and sdirix authored Nov 28, 2024
1 parent 48283ca commit 1301b45
Show file tree
Hide file tree
Showing 19 changed files with 528 additions and 175 deletions.
41 changes: 40 additions & 1 deletion packages/ai-core/src/browser/ai-core-preferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { interfaces } from '@theia/core/shared/inversify';
export const AI_CORE_PREFERENCES_TITLE = '✨ AI Features [Experimental]';
export const PREFERENCE_NAME_ENABLE_EXPERIMENTAL = 'ai-features.AiEnable.enableAI';
export const PREFERENCE_NAME_PROMPT_TEMPLATES = 'ai-features.promptTemplates.promptTemplatesFolder';
export const PREFERENCE_NAME_REQUEST_SETTINGS = 'ai-features.modelSettings.requestSettings';

export const aiCorePreferenceSchema: PreferenceSchema = {
type: 'object',
Expand Down Expand Up @@ -55,13 +56,51 @@ export const aiCorePreferenceSchema: PreferenceSchema = {
canSelectMany: false
}
},

},
[PREFERENCE_NAME_REQUEST_SETTINGS]: {
title: 'Custom Request Settings',
markdownDescription: 'Allows specifying custom request settings for multiple models.\n\
Each object represents the configuration for a specific model. The `modelId` field specifies the model ID, `requestSettings` defines model-specific settings.\n\
The `providerId` field is optional and allows you to apply the settings to a specific provider. If not set, the settings will be applied to all providers.\n\
Example providerIds: huggingface, openai, ollama, llamafile.\n\
Refer to [our documentation](https://theia-ide.org/docs/user_ai/#custom-request-settings) for more information.',
type: 'array',
items: {
type: 'object',
properties: {
modelId: {
type: 'string',
description: 'The model id'
},
requestSettings: {
type: 'object',
additionalProperties: true,
description: 'Settings for the specific model ID.',
},
providerId: {
type: 'string',
description: 'The (optional) provider id to apply the settings to. If not set, the settings will be applied to all providers.',
},
},
},
default: [],
}
}
};
/**
 * Values of the AI core preferences, keyed by preference name.
 */
export interface AICoreConfiguration {
    [PREFERENCE_NAME_ENABLE_EXPERIMENTAL]: boolean | undefined;
    [PREFERENCE_NAME_PROMPT_TEMPLATES]: string | undefined;
    // Reuses the named RequestSetting type instead of duplicating its shape
    // inline, so the two declarations cannot drift apart.
    [PREFERENCE_NAME_REQUEST_SETTINGS]: RequestSetting[] | undefined;
}

/**
 * One entry of the `ai-features.modelSettings.requestSettings` preference:
 * custom request settings for one model, optionally scoped to a single provider.
 */
export interface RequestSetting {
    /** Id of the model the settings apply to. */
    modelId: string;
    /** Free-form request settings forwarded with requests to the model. */
    requestSettings?: { [key: string]: unknown };
    /** Optional provider id (e.g. 'huggingface', 'ollama'); when unset the settings apply to all providers. */
    providerId?: string;
}

export const AICorePreferences = Symbol('AICorePreferences');
Expand Down
5 changes: 5 additions & 0 deletions packages/ai-core/src/common/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ export interface LanguageModelMetaData {
readonly family?: string;
readonly maxInputTokens?: number;
readonly maxOutputTokens?: number;
/**
* Default request settings for the language model. These settings can be set via user preferences.
* Settings in a request will override these default settings.
*/
readonly defaultRequestSettings?: { [key: string]: unknown };
}

export namespace LanguageModelMetaData {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import { FrontendApplicationContribution, PreferenceService } from '@theia/core/
import { inject, injectable } from '@theia/core/shared/inversify';
import { HuggingFaceLanguageModelsManager, HuggingFaceModelDescription } from '../common';
import { API_KEY_PREF, MODELS_PREF } from './huggingface-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const HUGGINGFACE_PROVIDER_ID = 'huggingface';
@injectable()
export class HuggingFaceFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -36,31 +38,58 @@ export class HuggingFaceFrontendApplicationContribution implements FrontendAppli
this.manager.setApiKey(apiKey);

const models = this.preferenceService.get<string[]>(MODELS_PREF, []);
this.manager.createOrUpdateLanguageModels(...models.map(createHuggingFaceModelDescription));
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
this.manager.createOrUpdateLanguageModels(...models.map(modelId => this.createHuggingFaceModelDescription(modelId, requestSettings)));
this.prevModels = [...models];

this.preferenceService.onPreferenceChanged(event => {
if (event.preferenceName === API_KEY_PREF) {
this.manager.setApiKey(event.newValue);
} else if (event.preferenceName === MODELS_PREF) {
const oldModels = new Set(this.prevModels);
const newModels = new Set(event.newValue as string[]);

const modelsToRemove = [...oldModels].filter(model => !newModels.has(model));
const modelsToAdd = [...newModels].filter(model => !oldModels.has(model));

this.manager.removeLanguageModels(...modelsToRemove.map(model => `huggingface/${model}`));
this.manager.createOrUpdateLanguageModels(...modelsToAdd.map(createHuggingFaceModelDescription));
this.prevModels = [...event.newValue];
this.handleModelChanges(event.newValue as string[]);
} else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
this.handleRequestSettingsChanges(event.newValue as RequestSetting[]);
}
});
});
}
}

function createHuggingFaceModelDescription(modelId: string): HuggingFaceModelDescription {
return {
id: `huggingface/${modelId}`,
model: modelId
};
protected handleModelChanges(newModels: string[]): void {
    // Diff the previously known model ids against the updated preference value.
    const previous = new Set(this.prevModels);
    const current = new Set(newModels);

    const removedIds = [...previous].filter(id => !current.has(id));
    const addedIds = [...current].filter(id => !previous.has(id));

    this.manager.removeLanguageModels(...removedIds.map(id => `${HUGGINGFACE_PROVIDER_ID}/${id}`));
    const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
    this.manager.createOrUpdateLanguageModels(
        ...addedIds.map(modelId => this.createHuggingFaceModelDescription(modelId, requestSettings))
    );
    this.prevModels = newModels;
}

protected handleRequestSettingsChanges(newSettings: RequestSetting[]): void {
    // Re-announce every configured model so the updated settings take effect.
    const configuredModels = this.preferenceService.get<string[]>(MODELS_PREF, []);
    const descriptions = configuredModels.map(modelId => this.createHuggingFaceModelDescription(modelId, newSettings));
    this.manager.createOrUpdateLanguageModels(...descriptions);
}

/**
 * Builds the model description for a Hugging Face model, attaching any
 * matching custom request settings from the user preference.
 *
 * An entry matches when its modelId equals the given model and its providerId
 * is either unset (applies to all providers) or 'huggingface'. Provider-specific
 * entries take precedence over provider-agnostic ones; among entries of equal
 * specificity the first one in preference order wins (sort is stable).
 *
 * @param modelId the model id as used by the Hugging Face API
 * @param requestSettings all configured request-setting entries
 */
protected createHuggingFaceModelDescription(
    modelId: string,
    requestSettings: RequestSetting[]
): HuggingFaceModelDescription {
    const id = `${HUGGINGFACE_PROVIDER_ID}/${modelId}`;
    const matchingSettings = requestSettings.filter(
        setting => (!setting.providerId || setting.providerId === HUGGINGFACE_PROVIDER_ID) && setting.modelId === modelId
    );
    // Prefer provider-specific entries; stable sort keeps preference order otherwise.
    matchingSettings.sort((a, b) => (a.providerId ? 0 : 1) - (b.providerId ? 0 : 1));
    if (matchingSettings.length > 1) {
        console.warn(
            `Multiple entries found for modelId "${modelId}". Using the first match and ignoring the rest.`
        );
    }
    const modelRequestSetting = matchingSettings[0];
    return {
        id: id,
        model: modelId,
        defaultRequestSettings: modelRequestSetting?.requestSettings
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ export interface HuggingFaceModelDescription {
* The model ID as used by the Hugging Face API.
*/
model: string;
/**
* Default request settings for the Hugging Face model.
*/
defaultRequestSettings?: { [key: string]: unknown };
}

export interface HuggingFaceLanguageModelsManager {
Expand Down
30 changes: 21 additions & 9 deletions packages/ai-hugging-face/src/node/huggingface-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ export class HuggingFaceModel implements LanguageModel {
* @param model the model id as it is used by the Hugging Face API
* @param apiKey function to retrieve the API key for Hugging Face
*/
constructor(public readonly id: string, public model: string, public apiKey: () => string | undefined) {
}
constructor(
public readonly id: string,
public model: string,
// Function to retrieve the API key for Hugging Face.
public apiKey: () => string | undefined,
// Optional metadata fields (name, vendor, version, family, token limits).
public readonly name?: string,
public readonly vendor?: string,
public readonly version?: string,
public readonly family?: string,
public readonly maxInputTokens?: number,
public readonly maxOutputTokens?: number,
// Fallback request settings used when a request specifies none (see getSettings).
public defaultRequestSettings?: Record<string, unknown>
) { }

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const hfInference = this.initializeHfInference();
Expand All @@ -67,15 +77,16 @@ export class HuggingFaceModel implements LanguageModel {
}
}

protected getDefaultSettings(): Record<string, unknown> {
return {
max_new_tokens: 2024,
stop: ['<|endoftext|>', '<eos>']
};
protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
    // Per-request settings win; otherwise fall back to the user-configured
    // defaults for this model; otherwise use an empty settings object.
    return request.settings ?? this.defaultRequestSettings ?? {};
}

protected async handleNonStreamingRequest(hfInference: HfInference, request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
const settings = request.settings || this.getDefaultSettings();
const settings = this.getSettings(request);

const response = await hfInference.textGeneration({
model: this.model,
Expand Down Expand Up @@ -104,7 +115,8 @@ export class HuggingFaceModel implements LanguageModel {
request: LanguageModelRequest,
cancellationToken?: CancellationToken
): Promise<LanguageModelResponse> {
const settings = request.settings || this.getDefaultSettings();

const settings = this.getSettings(request);

const stream = hfInference.textGenerationStream({
model: this.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ export class HuggingFaceLanguageModelsManagerImpl implements HuggingFaceLanguage
}
model.model = modelDescription.model;
model.apiKey = apiKeyProvider;
model.defaultRequestSettings = modelDescription.defaultRequestSettings;
} else {
this.languageModelRegistry.addLanguageModels([new HuggingFaceModel(modelDescription.id, modelDescription.model, apiKeyProvider)]);
this.languageModelRegistry.addLanguageModels([
new HuggingFaceModel(
modelDescription.id,
modelDescription.model,
apiKeyProvider,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
modelDescription.defaultRequestSettings
)
]);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import { AICommandHandlerFactory } from '@theia/ai-core/lib/browser/ai-command-h
import { CommandContribution, CommandRegistry, MessageService } from '@theia/core';
import { PreferenceService, QuickInputService } from '@theia/core/lib/browser';
import { inject, injectable } from '@theia/core/shared/inversify';
import { LlamafileEntry, LlamafileManager } from '../common/llamafile-manager';
import { LlamafileManager } from '../common/llamafile-manager';
import { PREFERENCE_LLAMAFILE } from './llamafile-preferences';
import { LlamafileEntry } from './llamafile-frontend-application-contribution';

export const StartLlamafileCommand = {
id: 'llamafile.start',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import { FrontendApplicationContribution, PreferenceService } from '@theia/core/lib/browser';
import { inject, injectable } from '@theia/core/shared/inversify';
import { LlamafileEntry, LlamafileManager } from '../common/llamafile-manager';
import { LlamafileManager, LlamafileModelDescription } from '../common/llamafile-manager';
import { PREFERENCE_LLAMAFILE } from './llamafile-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const LLAMAFILE_PROVIDER_ID = 'llamafile';
@injectable()
export class LlamafileFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -33,27 +35,92 @@ export class LlamafileFrontendApplicationContribution implements FrontendApplica
onStart(): void {
this.preferenceService.ready.then(() => {
const llamafiles = this.preferenceService.get<LlamafileEntry[]>(PREFERENCE_LLAMAFILE, []);
this.llamafileManager.addLanguageModels(llamafiles);
llamafiles.forEach(model => this._knownLlamaFiles.set(model.name, model));
const validLlamafiles = llamafiles.filter(LlamafileEntry.is);

const LlamafileModelDescriptions = this.getLLamaFileModelDescriptions(validLlamafiles);

this.llamafileManager.addLanguageModels(LlamafileModelDescriptions);
validLlamafiles.forEach(model => this._knownLlamaFiles.set(model.name, model));

this.preferenceService.onPreferenceChanged(event => {
if (event.preferenceName === PREFERENCE_LLAMAFILE) {
// only new models which are actual LLamaFileEntries
const newModels = event.newValue.filter((llamafileEntry: unknown) => LlamafileEntry.is(llamafileEntry)) as LlamafileEntry[];
this.handleLlamaFilePreferenceChange(newModels);
} else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
this.handleRequestSettingsChange(event.newValue as RequestSetting[]);
}
});
});
}

const llamafilesToAdd = newModels.filter(llamafile =>
!this._knownLlamaFiles.has(llamafile.name) || !LlamafileEntry.equals(this._knownLlamaFiles.get(llamafile.name)!, llamafile));
/**
 * Maps llamafile entries to model descriptions, attaching any matching custom
 * request settings from the user preference.
 *
 * An entry matches when its modelId equals the llamafile name and its
 * providerId is either unset (applies to all providers) or 'llamafile'.
 * Provider-specific entries take precedence over provider-agnostic ones;
 * among entries of equal specificity the first one in preference order wins.
 */
protected getLLamaFileModelDescriptions(llamafiles: LlamafileEntry[]): LlamafileModelDescription[] {
    const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
    return llamafiles.map(llamafile => {
        const matchingSettings = requestSettings.filter(
            setting =>
                (!setting.providerId || setting.providerId === LLAMAFILE_PROVIDER_ID) &&
                setting.modelId === llamafile.name
        );
        // Prefer provider-specific entries; stable sort keeps preference order otherwise.
        matchingSettings.sort((a, b) => (a.providerId ? 0 : 1) - (b.providerId ? 0 : 1));
        if (matchingSettings.length > 1) {
            console.warn(`Multiple entries found for model "${llamafile.name}". Using the first match.`);
        }
        return {
            name: llamafile.name,
            uri: llamafile.uri,
            port: llamafile.port,
            defaultRequestSettings: matchingSettings[0]?.requestSettings
        };
    });
}

const llamafileIdsToRemove = [...this._knownLlamaFiles.values()].filter(llamafile =>
!newModels.find(a => LlamafileEntry.equals(a, llamafile))).map(a => a.name);
protected handleLlamaFilePreferenceChange(newModels: LlamafileEntry[]): void {
const llamafilesToAdd = newModels.filter(llamafile =>
!this._knownLlamaFiles.has(llamafile.name) ||
!LlamafileEntry.equals(this._knownLlamaFiles.get(llamafile.name)!, llamafile));

this.llamafileManager.removeLanguageModels(llamafileIdsToRemove);
llamafileIdsToRemove.forEach(model => this._knownLlamaFiles.delete(model));
const llamafileIdsToRemove = [...this._knownLlamaFiles.values()].filter(llamafile =>
!newModels.find(newModel => LlamafileEntry.equals(newModel, llamafile)))
.map(llamafile => llamafile.name);

this.llamafileManager.addLanguageModels(llamafilesToAdd);
llamafilesToAdd.forEach(model => this._knownLlamaFiles.set(model.name, model));
}
});
this.llamafileManager.removeLanguageModels(llamafileIdsToRemove);
llamafileIdsToRemove.forEach(id => this._knownLlamaFiles.delete(id));

this.llamafileManager.addLanguageModels(this.getLLamaFileModelDescriptions(llamafilesToAdd));
llamafilesToAdd.forEach(model => this._knownLlamaFiles.set(model.name, model));
}

protected handleRequestSettingsChange(newSettings: RequestSetting[]): void {
const llamafiles = Array.from(this._knownLlamaFiles.values());
const llamafileModelDescriptions = this.getLLamaFileModelDescriptions(llamafiles);
llamafileModelDescriptions.forEach(llamafileModelDescription => {
this.llamafileManager.updateRequestSettings(llamafileModelDescription.name, llamafileModelDescription.defaultRequestSettings);
});
}
}

/**
 * A llamafile configuration entry: display name, file URI and server port.
 */
export interface LlamafileEntry {
    name: string;
    uri: string;
    port: number;
}

namespace LlamafileEntry {
    /** Field-wise equality of two entries. */
    export function equals(a: LlamafileEntry, b: LlamafileEntry): boolean {
        return a.name === b.name
            && a.uri === b.uri
            && a.port === b.port;
    }

    /** Runtime type guard validating an unknown preference value. */
    export function is(entry: unknown): entry is LlamafileEntry {
        // eslint-disable-next-line no-null/no-null
        if (typeof entry !== 'object' || entry === null) {
            return false;
        }
        const candidate = entry as Partial<LlamafileEntry>;
        return typeof candidate.name === 'string'
            && typeof candidate.uri === 'string'
            && typeof candidate.port === 'number';
    }
}
Loading

0 comments on commit 1301b45

Please sign in to comment.