diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts index f8167d4dc52d5..a4713d1223968 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts @@ -3,6 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import { coalesce } from 'vs/base/common/arrays'; import { CancellationToken } from 'vs/base/common/cancellation'; import { Emitter, Event } from 'vs/base/common/event'; import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from 'vs/base/common/lifecycle'; @@ -36,7 +37,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { ) { this._proxy = extHostContext.getProxy(ExtHostContext.ExtHostChatProvider); - this._proxy.$updateLanguageModels({ added: _chatProviderService.getLanguageModelIds() }); + this._proxy.$updateLanguageModels({ added: coalesce(_chatProviderService.getLanguageModelIds().map(id => _chatProviderService.lookupLanguageModel(id))) }); this._store.add(_chatProviderService.onDidChangeLanguageModels(this._proxy.$updateLanguageModels, this._proxy)); } @@ -91,7 +92,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { await Promise.race([ activate, - Event.toPromise(Event.filter(this._chatProviderService.onDidChangeLanguageModels, e => Boolean(e.added?.includes(providerId)))) + Event.toPromise(Event.filter(this._chatProviderService.onDidChangeLanguageModels, e => Boolean(e.added?.some(value => value.identifier === providerId)))) ]); return this._chatProviderService.lookupLanguageModel(providerId); diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index 8a99d63cb4c30..da678cb6efb58 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -25,11 +25,11 @@ import { CandidatePortSource, ExtHostContext, ExtHostLogLevelServiceShape, MainC import { ExtHostRelatedInformation } from 'vs/workbench/api/common/extHostAiRelatedInformation'; import { ExtHostApiCommands } from 'vs/workbench/api/common/extHostApiCommands'; import { IExtHostApiDeprecationService } from 'vs/workbench/api/common/extHostApiDeprecationService'; -import { ExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; +import { IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; import { ExtHostBulkEdits } from 'vs/workbench/api/common/extHostBulkEdits'; import { ExtHostChat } from 'vs/workbench/api/common/extHostChat'; import { ExtHostChatAgents2 } from 'vs/workbench/api/common/extHostChatAgents2'; -import { ExtHostLanguageModels } from 'vs/workbench/api/common/extHostLanguageModels'; +import { IExtHostLanguageModels } from 'vs/workbench/api/common/extHostLanguageModels'; import { ExtHostChatVariables } from 'vs/workbench/api/common/extHostChatVariables'; import { ExtHostClipboard } from 'vs/workbench/api/common/extHostClipboard'; import { ExtHostEditorInsets } from 'vs/workbench/api/common/extHostCodeInsets'; @@ -143,6 +143,8 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const extHostSecretState = accessor.get(IExtHostSecretState); const extHostEditorTabs = accessor.get(IExtHostEditorTabs); const extHostManagedSockets = accessor.get(IExtHostManagedSockets); + const extHostAuthentication = accessor.get(IExtHostAuthentication); + const extHostLanguageModels = accessor.get(IExtHostLanguageModels); // register addressable instances rpcProtocol.set(ExtHostContext.ExtHostFileSystemInfo, extHostFileSystemInfo); @@ -157,6 +159,8 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I rpcProtocol.set(ExtHostContext.ExtHostTelemetry, extHostTelemetry); rpcProtocol.set(ExtHostContext.ExtHostEditorTabs, extHostEditorTabs); rpcProtocol.set(ExtHostContext.ExtHostManagedSockets, extHostManagedSockets); + rpcProtocol.set(ExtHostContext.ExtHostAuthentication, extHostAuthentication); + rpcProtocol.set(ExtHostContext.ExtHostChatProvider, extHostLanguageModels); // automatically create and register addressable instances const extHostDecorations = rpcProtocol.set(ExtHostContext.ExtHostDecorations, accessor.get(IExtHostDecorations)); @@ -196,7 +200,6 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const extHostProgress = rpcProtocol.set(ExtHostContext.ExtHostProgress, new ExtHostProgress(rpcProtocol.getProxy(MainContext.MainThreadProgress))); const extHostLabelService = rpcProtocol.set(ExtHostContext.ExtHostLabelService, new ExtHostLabelService(rpcProtocol)); const extHostTheming = rpcProtocol.set(ExtHostContext.ExtHostTheming, new ExtHostTheming(rpcProtocol)); - const extHostAuthentication = rpcProtocol.set(ExtHostContext.ExtHostAuthentication, new ExtHostAuthentication(rpcProtocol)); const extHostTimeline = rpcProtocol.set(ExtHostContext.ExtHostTimeline, new ExtHostTimeline(rpcProtocol, extHostCommands)); const extHostWebviews = rpcProtocol.set(ExtHostContext.ExtHostWebviews, new ExtHostWebviews(rpcProtocol, initData.remote, extHostWorkspace, extHostLogService, extHostApiDeprecation)); const extHostWebviewPanels = rpcProtocol.set(ExtHostContext.ExtHostWebviewPanels, new ExtHostWebviewPanels(rpcProtocol, extHostWebviews, extHostWorkspace)); @@ -207,7 +210,6 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const extHostProfileContentHandlers = rpcProtocol.set(ExtHostContext.ExtHostProfileContentHandlers, new ExtHostProfileContentHandlers(rpcProtocol)); rpcProtocol.set(ExtHostContext.ExtHostInteractive, new ExtHostInteractive(rpcProtocol, extHostNotebook, extHostDocumentsAndEditors, extHostCommands, extHostLogService)); const extHostInteractiveEditor = rpcProtocol.set(ExtHostContext.ExtHostInlineChat, new ExtHostInteractiveEditor(rpcProtocol, extHostCommands, extHostDocuments, extHostLogService)); - const extHostChatProvider = rpcProtocol.set(ExtHostContext.ExtHostChatProvider, new ExtHostLanguageModels(rpcProtocol, extHostLogService, extHostAuthentication)); const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostLogService, extHostCommands)); const extHostChatVariables = rpcProtocol.set(ExtHostContext.ExtHostChatVariables, new ExtHostChatVariables(rpcProtocol)); const extHostChat = rpcProtocol.set(ExtHostContext.ExtHostChat, new ExtHostChat(rpcProtocol)); @@ -1413,7 +1415,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const chat: typeof vscode.chat = { registerChatResponseProvider(id: string, provider: vscode.ChatResponseProvider, metadata: vscode.ChatResponseProviderMetadata) { checkProposedApiEnabled(extension, 'chatProvider'); - return extHostChatProvider.registerLanguageModel(extension, id, provider, metadata); + return extHostLanguageModels.registerLanguageModel(extension, id, provider, metadata); }, registerChatVariableResolver(name: string, description: string, resolver: vscode.ChatVariableResolver) { checkProposedApiEnabled(extension, 'chatVariableResolver'); @@ -1433,15 +1435,15 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const lm: typeof vscode.lm = { get languageModels() { checkProposedApiEnabled(extension, 'languageModels'); - return extHostChatProvider.getLanguageModelIds(); + return extHostLanguageModels.getLanguageModelIds(); }, onDidChangeLanguageModels: (listener, thisArgs?, disposables?) => { checkProposedApiEnabled(extension, 'languageModels'); - return extHostChatProvider.onDidChangeProviders(listener, thisArgs, disposables); + return extHostLanguageModels.onDidChangeProviders(listener, thisArgs, disposables); }, sendChatRequest(languageModel: string, messages: vscode.LanguageModelChatMessage[], options: vscode.LanguageModelChatRequestOptions, token: vscode.CancellationToken) { checkProposedApiEnabled(extension, 'languageModels'); - return extHostChatProvider.sendChatRequest(extension, languageModel, messages, options, token); + return extHostLanguageModels.sendChatRequest(extension, languageModel, messages, options, token); } }; diff --git a/src/vs/workbench/api/common/extHost.common.services.ts b/src/vs/workbench/api/common/extHost.common.services.ts index faf45b596a757..d48b0d4739127 100644 --- a/src/vs/workbench/api/common/extHost.common.services.ts +++ b/src/vs/workbench/api/common/extHost.common.services.ts @@ -28,11 +28,15 @@ import { ILoggerService } from 'vs/platform/log/common/log'; import { ExtHostVariableResolverProviderService, IExtHostVariableResolverProvider } from 'vs/workbench/api/common/extHostVariableResolverService'; import { ExtHostLocalizationService, IExtHostLocalizationService } from 'vs/workbench/api/common/extHostLocalizationService'; import { ExtHostManagedSockets, IExtHostManagedSockets } from 'vs/workbench/api/common/extHostManagedSockets'; +import { ExtHostAuthentication, IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; +import { ExtHostLanguageModels, IExtHostLanguageModels } from 'vs/workbench/api/common/extHostLanguageModels'; registerSingleton(IExtHostLocalizationService, ExtHostLocalizationService, InstantiationType.Delayed); registerSingleton(ILoggerService, ExtHostLoggerService, InstantiationType.Delayed); registerSingleton(IExtHostApiDeprecationService, ExtHostApiDeprecationService, InstantiationType.Delayed); registerSingleton(IExtHostCommands, ExtHostCommands, InstantiationType.Eager); +registerSingleton(IExtHostAuthentication, ExtHostAuthentication, InstantiationType.Eager); +registerSingleton(IExtHostLanguageModels, ExtHostLanguageModels, InstantiationType.Eager); registerSingleton(IExtHostConfiguration, ExtHostConfiguration, InstantiationType.Eager); registerSingleton(IExtHostConsumerFileSystem, ExtHostConsumerFileSystem, InstantiationType.Eager); registerSingleton(IExtHostDebugService, WorkerExtHostDebugService, InstantiationType.Eager); diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index d2ae7af630d50..f57a727dd3368 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1189,7 +1189,7 @@ export interface MainThreadLanguageModelsShape extends IDisposable { } export interface ExtHostLanguageModelsShape { - $updateLanguageModels(data: { added?: string[]; removed?: string[] }): void; + $updateLanguageModels(data: { added?: ILanguageModelChatMetadata[]; removed?: string[] }): void; $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void; $provideLanguageModelResponse(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise; $handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise; diff --git a/src/vs/workbench/api/common/extHostAuthentication.ts b/src/vs/workbench/api/common/extHostAuthentication.ts index 84ddbd6fa5508..1c562edf76a19 100644 --- a/src/vs/workbench/api/common/extHostAuthentication.ts +++ b/src/vs/workbench/api/common/extHostAuthentication.ts @@ -5,10 +5,15 @@ import type * as vscode from 'vscode'; import { Emitter, Event } from 'vs/base/common/event'; -import { IMainContext, MainContext, MainThreadAuthenticationShape, ExtHostAuthenticationShape } from 'vs/workbench/api/common/extHost.protocol'; +import { MainContext, MainThreadAuthenticationShape, ExtHostAuthenticationShape } from 'vs/workbench/api/common/extHost.protocol'; import { Disposable } from 'vs/workbench/api/common/extHostTypes'; import { IExtensionDescription, ExtensionIdentifier } from 'vs/platform/extensions/common/extensions'; import { INTERNAL_AUTH_PROVIDER_PREFIX } from 'vs/workbench/services/authentication/common/authentication'; +import { createDecorator } from 'vs/platform/instantiation/common/instantiation'; +import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService'; + +export interface IExtHostAuthentication extends ExtHostAuthentication { } +export const IExtHostAuthentication = createDecorator('IExtHostAuthentication'); interface ProviderWithMetadata { label: string; @@ -17,6 +22,9 @@ interface ProviderWithMetadata { } export class ExtHostAuthentication implements ExtHostAuthenticationShape { + + declare _serviceBrand: undefined; + private _proxy: MainThreadAuthenticationShape; private _authenticationProviders: Map = new Map(); @@ -26,8 +34,10 @@ export class ExtHostAuthentication implements ExtHostAuthenticationShape { private _getSessionTaskSingler = new TaskSingler(); private _getSessionsTaskSingler = new TaskSingler>(); - constructor(mainContext: IMainContext) { - this._proxy = mainContext.getProxy(MainContext.MainThreadAuthentication); + constructor( + @IExtHostRpcService extHostRpc: IExtHostRpcService + ) { + this._proxy = extHostRpc.getProxy(MainContext.MainThreadAuthentication); } async getSession(requestingExtension: IExtensionDescription, providerId: string, scopes: readonly string[], options: vscode.AuthenticationGetSessionOptions & ({ createIfNone: true } | { forceNewSession: true } | { forceNewSession: vscode.AuthenticationForceNewSessionOptions })): Promise; diff --git a/src/vs/workbench/api/common/extHostExtensionService.ts b/src/vs/workbench/api/common/extHostExtensionService.ts index 91f910aa4c8fd..56fa013289380 100644 --- a/src/vs/workbench/api/common/extHostExtensionService.ts +++ b/src/vs/workbench/api/common/extHostExtensionService.ts @@ -36,6 +36,7 @@ import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService'; import { ServiceCollection } from 'vs/platform/instantiation/common/serviceCollection'; import { IExtHostTunnelService } from 'vs/workbench/api/common/extHostTunnelService'; import { IExtHostTerminalService } from 'vs/workbench/api/common/extHostTerminalService'; +import { IExtHostLanguageModels } from 'vs/workbench/api/common/extHostLanguageModels'; import { Emitter, Event } from 'vs/base/common/event'; import { IExtensionActivationHost, checkActivateWorkspaceContainsExtension } from 'vs/workbench/services/extensions/common/workspaceContains'; import { ExtHostSecretState, IExtHostSecretState } from 'vs/workbench/api/common/extHostSecretState'; @@ -136,6 +137,7 @@ export abstract class AbstractExtHostExtensionService extends Disposable impleme @IExtHostTerminalService extHostTerminalService: IExtHostTerminalService, @IExtHostLocalizationService extHostLocalizationService: IExtHostLocalizationService, @IExtHostManagedSockets private readonly _extHostManagedSockets: IExtHostManagedSockets, + @IExtHostLanguageModels private readonly _extHostLanguageModels: IExtHostLanguageModels, ) { super(); this._hostUtils = hostUtils; @@ -489,6 +491,7 @@ export abstract class AbstractExtHostExtensionService extends Disposable impleme private _loadExtensionContext(extensionDescription: IExtensionDescription): Promise { + const lanuageModelAccessInformation = this._extHostLanguageModels.createLanguageModelAccessInformation(extensionDescription); const globalState = new ExtensionGlobalMemento(extensionDescription, this._storage); const workspaceState = new ExtensionMemento(extensionDescription.identifier.value, false, this._storage); const secrets = new ExtensionSecrets(extensionDescription, this._secretState); @@ -517,6 +520,7 @@ export abstract class AbstractExtHostExtensionService extends Disposable impleme workspaceState, secrets, subscriptions: [], + get languageModelAccessInformation() { return lanuageModelAccessInformation; }, get extensionUri() { return extensionDescription.extensionLocation; }, get extensionPath() { return extensionDescription.extensionLocation.fsPath; }, asAbsolutePath(relativePath: string) { return path.join(extensionDescription.extensionLocation.fsPath, relativePath); }, diff --git a/src/vs/workbench/api/common/extHostLanguageModels.ts b/src/vs/workbench/api/common/extHostLanguageModels.ts index ea10161e75f36..8221534013c05 100644 --- a/src/vs/workbench/api/common/extHostLanguageModels.ts +++ b/src/vs/workbench/api/common/extHostLanguageModels.ts @@ -5,8 +5,7 @@ import { CancellationToken } from 'vs/base/common/cancellation'; import { IDisposable, toDisposable } from 'vs/base/common/lifecycle'; -import { ILogService } from 'vs/platform/log/common/log'; -import { ExtHostLanguageModelsShape, IMainContext, MainContext, MainThreadLanguageModelsShape } from 'vs/workbench/api/common/extHost.protocol'; +import { ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from 'vs/workbench/api/common/extHost.protocol'; import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; import { LanguageModelError } from 'vs/workbench/api/common/extHostTypes'; import type * as vscode from 'vscode'; @@ -14,13 +13,21 @@ import { Progress } from 'vs/platform/progress/common/progress'; import { IChatMessage, IChatResponseFragment, ILanguageModelChatMetadata } from 'vs/workbench/contrib/chat/common/languageModels'; import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { AsyncIterableSource, Barrier } from 'vs/base/common/async'; -import { Emitter } from 'vs/base/common/event'; -import { ExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; +import { Emitter, Event } from 'vs/base/common/event'; import { localize } from 'vs/nls'; import { INTERNAL_AUTH_PROVIDER_PREFIX } from 'vs/workbench/services/authentication/common/authentication'; import { CancellationError } from 'vs/base/common/errors'; +import { createDecorator } from 'vs/platform/instantiation/common/instantiation'; +import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService'; +import { IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; +import { ILogService } from 'vs/platform/log/common/log'; + +export interface IExtHostLanguageModels extends ExtHostLanguageModels { } + +export const IExtHostLanguageModels = createDecorator('IExtHostLanguageModels'); type LanguageModelData = { + readonly languageModelId: string; readonly extension: ExtensionIdentifier; readonly provider: vscode.ChatResponseProvider; }; @@ -106,6 +113,8 @@ class LanguageModelResponse { export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { + declare _serviceBrand: undefined; + private static _idPool = 1; private readonly _proxy: MainThreadLanguageModelsShape; @@ -114,17 +123,16 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { readonly onDidChangeProviders = this._onDidChangeProviders.event; private readonly _languageModels = new Map(); - private readonly _languageModelIds = new Set(); // these are ALL models, not just the one in this EH + private readonly _allLanguageModelData = new Map(); // these are ALL models, not just the one in this EH private readonly _modelAccessList = new ExtensionIdentifierMap(); private readonly _pendingRequest = new Map(); - constructor( - mainContext: IMainContext, - private readonly _logService: ILogService, - private readonly _extHostAuthentication: ExtHostAuthentication, + @IExtHostRpcService extHostRpc: IExtHostRpcService, + @ILogService private readonly _logService: ILogService, + @IExtHostAuthentication private readonly _extHostAuthentication: IExtHostAuthentication, ) { - this._proxy = mainContext.getProxy(MainContext.MainThreadLanguageModels); + this._proxy = extHostRpc.getProxy(MainContext.MainThreadLanguageModels); } dispose(): void { @@ -135,7 +143,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { registerLanguageModel(extension: IExtensionDescription, identifier: string, provider: vscode.ChatResponseProvider, metadata: vscode.ChatResponseProviderMetadata): IDisposable { const handle = ExtHostLanguageModels._idPool++; - this._languageModels.set(handle, { extension: extension.identifier, provider }); + this._languageModels.set(handle, { extension: extension.identifier, provider, languageModelId: identifier }); let auth; if (metadata.auth) { auth = { @@ -145,6 +153,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { } this._proxy.$registerLanguageModelProvider(handle, identifier, { extension: extension.identifier, + identifier: identifier, model: metadata.name ?? '', auth }); @@ -173,19 +182,19 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { //#region --- making request - $updateLanguageModels(data: { added?: string[] | undefined; removed?: string[] | undefined }): void { + $updateLanguageModels(data: { added?: ILanguageModelChatMetadata[] | undefined; removed?: string[] | undefined }): void { const added: string[] = []; const removed: string[] = []; if (data.added) { - for (const id of data.added) { - this._languageModelIds.add(id); - added.push(id); + for (const metadata of data.added) { + this._allLanguageModelData.set(metadata.identifier, metadata); + added.push(metadata.model); } } if (data.removed) { for (const id of data.removed) { // clean up - this._languageModelIds.delete(id); + this._allLanguageModelData.delete(id); removed.push(id); // cancel pending requests for this model @@ -202,10 +211,13 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { added: Object.freeze(added), removed: Object.freeze(removed) })); + + // TODO@jrieken@TylerLeonhardt - this is a temporary hack to populate the auth providers + data.added?.forEach(this._fakeAuthPopulate, this); } getLanguageModelIds(): string[] { - return Array.from(this._languageModelIds); + return Array.from(this._allLanguageModelData.keys()); } $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void { @@ -232,7 +244,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { const from = extension.identifier; const metadata = await this._proxy.$prepareChatAccess(from, languageModelId, options.justification); - if (!metadata || !this._languageModelIds.has(languageModelId)) { + if (!metadata || !this._allLanguageModelData.has(languageModelId)) { throw LanguageModelError.NotFound(`Language model '${languageModelId}' is unknown.`); } @@ -323,4 +335,50 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { // And we're asking from a different extension && !ExtensionIdentifier.equals(toMetadata.extension, from); } + + private async _fakeAuthPopulate(metadata: ILanguageModelChatMetadata): Promise { + + for (const from of this._languageAccessInformationExtensions) { + try { + await this._getAuthAccess(from, { identifier: metadata.extension, displayName: '' }, undefined, true); + } catch (err) { + this._logService.error('Fake Auth request failed'); + this._logService.error(err); + } + } + + } + + private readonly _languageAccessInformationExtensions = new Set>(); + + createLanguageModelAccessInformation(from: Readonly): vscode.LanguageModelAccessInformation { + + this._languageAccessInformationExtensions.add(from); + + const that = this; + const _onDidChangeAccess = Event.signal(Event.filter(this._onDidChangeModelAccess.event, e => ExtensionIdentifier.equals(e.from, from.identifier))); + const _onDidAddRemove = Event.signal(this._onDidChangeProviders.event); + + return { + get onDidChange() { + return Event.any(_onDidChangeAccess, _onDidAddRemove); + }, + canSendRequest(languageModelId: string): boolean | undefined { + + const data = that._allLanguageModelData.get(languageModelId); + if (!data) { + return undefined; + } + if (!that._isUsingAuth(from.identifier, data)) { + return true; + } + + const list = that._modelAccessList.get(from.identifier); + if (!list) { + return undefined; + } + return list.has(data.extension); + } + }; + } } diff --git a/src/vs/workbench/contrib/chat/common/languageModels.ts b/src/vs/workbench/contrib/chat/common/languageModels.ts index 7f93594aee3f3..7304d00262887 100644 --- a/src/vs/workbench/contrib/chat/common/languageModels.ts +++ b/src/vs/workbench/contrib/chat/common/languageModels.ts @@ -28,6 +28,7 @@ export interface IChatResponseFragment { export interface ILanguageModelChatMetadata { readonly extension: ExtensionIdentifier; + readonly identifier: string; readonly model: string; readonly description?: string; readonly auth?: { @@ -47,7 +48,7 @@ export interface ILanguageModelsService { readonly _serviceBrand: undefined; - onDidChangeLanguageModels: Event<{ added?: string[]; removed?: string[] }>; + onDidChangeLanguageModels: Event<{ added?: ILanguageModelChatMetadata[]; removed?: string[] }>; getLanguageModelIds(): string[]; @@ -63,8 +64,8 @@ export class LanguageModelsService implements ILanguageModelsService { private readonly _providers: Map = new Map(); - private readonly _onDidChangeProviders = new Emitter<{ added?: string[]; removed?: string[] }>(); - readonly onDidChangeLanguageModels: Event<{ added?: string[]; removed?: string[] }> = this._onDidChangeProviders.event; + private readonly _onDidChangeProviders = new Emitter<{ added?: ILanguageModelChatMetadata[]; removed?: string[] }>(); + readonly onDidChangeLanguageModels: Event<{ added?: ILanguageModelChatMetadata[]; removed?: string[] }> = this._onDidChangeProviders.event; dispose() { this._onDidChangeProviders.dispose(); @@ -84,7 +85,7 @@ export class LanguageModelsService implements ILanguageModelsService { throw new Error(`Chat response provider with identifier ${identifier} is already registered.`); } this._providers.set(identifier, provider); - this._onDidChangeProviders.fire({ added: [identifier] }); + this._onDidChangeProviders.fire({ added: [provider.metadata] }); return toDisposable(() => { if (this._providers.delete(identifier)) { this._onDidChangeProviders.fire({ removed: [identifier] }); diff --git a/src/vscode-dts/vscode.proposed.languageModels.d.ts b/src/vscode-dts/vscode.proposed.languageModels.d.ts index 775e4d5e1dc46..8732b0e67e095 100644 --- a/src/vscode-dts/vscode.proposed.languageModels.d.ts +++ b/src/vscode-dts/vscode.proposed.languageModels.d.ts @@ -149,7 +149,7 @@ declare module 'vscode' { /** * Do not show the consent UI if the user has not yet granted access to the language model but fail the request instead. */ - // TODO@API refine/define + // TODO@API Revisit this, how do you do the first request? silent?: boolean; /** @@ -178,10 +178,10 @@ declare module 'vscode' { * @param messages An array of message instances. * @param options Objects that control the request. * @param token A cancellation token which controls the request. See {@link CancellationTokenSource} for how to create one. - * @returns A thenable that resolves to a {@link LanguageModelChatResponse}. The promise will reject when making the request failed. + * @returns A thenable that resolves to a {@link LanguageModelChatResponse}. The promise will reject when the request couldn't be made. */ // TODO@API refine doc - // TODO@API ExtensionContext#permission#languageModels: { languageModel: string: LanguageModelAccessInformation} + // TODO@API ✅ ExtensionContext#permission#languageModels: { languageModel: string: LanguageModelAccessInformation} // TODO@API ✅ define specific error types? // TODO@API ✅ NAME: sendChatRequest, fetchChatResponse, makeChatRequest, chat, chatRequest sendChatRequest // TODO@API ✅ NAME: LanguageModelChatXYZMessage @@ -201,11 +201,39 @@ declare module 'vscode' { export const onDidChangeLanguageModels: Event; } - // export function chatRequest2(languageModel: string, callback: (request: LanguageModelRequest) => R): Thenable; + /** + * Represents extension specific information about the access to language models. + */ + export interface LanguageModelAccessInformation { + + /** + * An event that fires when access information changes. + */ + onDidChange: Event; - // interface LanguageModelRequest { - // readonly quota: any; - // readonly permissions: any; - // makeRequest(messages: LanguageModelChatMessage[], options: { [name: string]: any }, token: CancellationToken): LanguageModelChatResponse; - // } + /** + * Checks if a request can be made to a language model. + * + * *Note* that calling this function will not trigger a consent UI but just checks. + * + * @param languageModelId A language model identifier. + * @return `true` if a request can be made, `false` if not, `undefined` if the language + * model does not exist or consent hasn't been asked for. + */ + canSendRequest(languageModelId: string): boolean | undefined; + + // TODO@API SYNC or ASYNC? + // TODO@API future + // retrieveQuota(languageModelId: string): { remaining: number; resets: Date }; + } + + export interface ExtensionContext { + + /** + * An object that keeps information about how this extension can use language models. + * + * @see {@link lm.sendChatRequest} + */ + readonly languageModelAccessInformation: LanguageModelAccessInformation; + } }