Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: configure dynamic providers via .env #1108

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 27 additions & 54 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
* Preventing TS checks with files presented in the video for a better presentation.
*/
import type { Message } from 'ai';
import React, { type RefCallback, useCallback, useEffect, useState } from 'react';
import React, { type RefCallback, useEffect, useState } from 'react';
import { ClientOnly } from 'remix-utils/client-only';
import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, PROVIDER_LIST, initializeModelList } from '~/utils/constants';
import { PROVIDER_LIST } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { APIKeyManager, getApiKeysFromCookies } from './APIKeyManager';
Expand All @@ -25,13 +25,13 @@ import GitCloneButton from './GitCloneButton';
import FilePreview from './FilePreview';
import { ModelSelector } from '~/components/chat/ModelSelector';
import { SpeechRecognitionButton } from '~/components/chat/SpeechRecognition';
import type { IProviderSetting, ProviderInfo } from '~/types/model';
import type { ProviderInfo } from '~/types/model';
import { ScreenshotStateManager } from './ScreenshotStateManager';
import { toast } from 'react-toastify';
import StarterTemplates from './StarterTemplates';
import type { ActionAlert } from '~/types/actions';
import ChatAlert from './ChatAlert';
import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';

const TEXTAREA_MIN_HEIGHT = 76;

Expand Down Expand Up @@ -102,35 +102,13 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [apiKeys, setApiKeys] = useState<Record<string, string>>(getApiKeysFromCookies());
const [modelList, setModelList] = useState(MODEL_LIST);
const [modelList, setModelList] = useState<ModelInfo[]>([]);
const [isModelSettingsCollapsed, setIsModelSettingsCollapsed] = useState(false);
const [isListening, setIsListening] = useState(false);
const [recognition, setRecognition] = useState<SpeechRecognition | null>(null);
const [transcript, setTranscript] = useState('');
const [isModelLoading, setIsModelLoading] = useState<string | undefined>('all');

const getProviderSettings = useCallback(() => {
let providerSettings: Record<string, IProviderSetting> | undefined = undefined;

try {
const savedProviderSettings = Cookies.get('providers');

if (savedProviderSettings) {
const parsedProviderSettings = JSON.parse(savedProviderSettings);

if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) {
providerSettings = parsedProviderSettings;
}
}
} catch (error) {
console.error('Error loading Provider Settings from cookies:', error);

// Clear invalid cookie data
Cookies.remove('providers');
}

return providerSettings;
}, []);
useEffect(() => {
console.log(transcript);
}, [transcript]);
Expand Down Expand Up @@ -169,25 +147,25 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(

useEffect(() => {
if (typeof window !== 'undefined') {
const providerSettings = getProviderSettings();
let parsedApiKeys: Record<string, string> | undefined = {};

try {
parsedApiKeys = getApiKeysFromCookies();
setApiKeys(parsedApiKeys);
} catch (error) {
console.error('Error loading API keys from cookies:', error);

// Clear invalid cookie data
Cookies.remove('apiKeys');
}

setIsModelLoading('all');
initializeModelList({ apiKeys: parsedApiKeys, providerSettings })
.then((modelList) => {
setModelList(modelList);
fetch('/api/models')
.then((response) => response.json())
.then((data) => {
const typedData = data as { modelList: ModelInfo[] };
setModelList(typedData.modelList);
})
.catch((error) => {
console.error('Error initializing model list:', error);
console.error('Error fetching model list:', error);
})
.finally(() => {
setIsModelLoading(undefined);
Expand All @@ -200,29 +178,24 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
setApiKeys(newApiKeys);
Cookies.set('apiKeys', JSON.stringify(newApiKeys));

const provider = LLMManager.getInstance(import.meta.env || process.env || {}).getProvider(providerName);
setIsModelLoading(providerName);

if (provider && provider.getDynamicModels) {
setIsModelLoading(providerName);
let providerModels: ModelInfo[] = [];

try {
const providerSettings = getProviderSettings();
const staticModels = provider.staticModels;
const dynamicModels = await provider.getDynamicModels(
newApiKeys,
providerSettings,
import.meta.env || process.env || {},
);

setModelList((preModels) => {
const filteredOutPreModels = preModels.filter((x) => x.provider !== providerName);
return [...filteredOutPreModels, ...staticModels, ...dynamicModels];
});
} catch (error) {
console.error('Error loading dynamic models:', error);
}
setIsModelLoading(undefined);
try {
const response = await fetch(`/api/models/${encodeURIComponent(providerName)}`);
const data = await response.json();
providerModels = (data as { modelList: ModelInfo[] }).modelList;
} catch (error) {
console.error('Error loading dynamic models for:', providerName, error);
}
mrsimpson marked this conversation as resolved.
Show resolved Hide resolved

// Only update models for the specific provider
setModelList((prevModels) => {
const otherModels = prevModels.filter((model) => model.provider !== providerName);
return [...otherModels, ...providerModels];
});
setIsModelLoading(undefined);
};

const startListening = () => {
Expand Down
33 changes: 33 additions & 0 deletions app/lib/api/cookies.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
export function parseCookies(cookieHeader: string | null) {
const cookies: Record<string, string> = {};

if (!cookieHeader) {
return cookies;
}

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(';').map((cookie) => cookie.trim());

items.forEach((item) => {
const [name, ...rest] = item.split('=');

if (name && rest.length > 0) {
// Decode the name and value, and join value parts in case it contains '='
const decodedName = decodeURIComponent(name.trim());
const decodedValue = decodeURIComponent(rest.join('=').trim());
cookies[decodedName] = decodedValue;
}
});

return cookies;
}

export function getApiKeysFromCookie(cookieHeader: string | null): Record<string, string> {
const cookies = parseCookies(cookieHeader);
return cookies.apiKeys ? JSON.parse(cookies.apiKeys) : {};
}

export function getProviderSettingsFromCookie(cookieHeader: string | null): Record<string, any> {
const cookies = parseCookies(cookieHeader);
return cookies.providers ? JSON.parse(cookies.providers) : {};
}
2 changes: 1 addition & 1 deletion app/lib/modules/llm/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export class LLMManager {

let enabledProviders = Array.from(this._providers.values()).map((p) => p.name);

if (providerSettings) {
if (providerSettings && Object.keys(providerSettings).length > 0) {
enabledProviders = enabledProviders.filter((p) => providerSettings[p].enabled);
}

Expand Down
33 changes: 4 additions & 29 deletions app/routes/api.enhancer.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,13 @@
import { type ActionFunctionArgs } from '@remix-run/cloudflare';

//import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text';
import { stripIndents } from '~/utils/stripIndent';
import type { IProviderSetting, ProviderInfo } from '~/types/model';
import type { ProviderInfo } from '~/types/model';
import { getApiKeysFromCookie, getProviderSettingsFromCookie } from '~/lib/api/cookies';

export async function action(args: ActionFunctionArgs) {
return enhancerAction(args);
}

function parseCookies(cookieHeader: string) {
const cookies: any = {};

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(';').map((cookie) => cookie.trim());

items.forEach((item) => {
const [name, ...rest] = item.split('=');

if (name && rest) {
// Decode the name and value, and join value parts in case it contains '='
const decodedName = decodeURIComponent(name.trim());
const decodedValue = decodeURIComponent(rest.join('=').trim());
cookies[decodedName] = decodedValue;
}
});

return cookies;
}

async function enhancerAction({ context, request }: ActionFunctionArgs) {
const { message, model, provider } = await request.json<{
message: string;
Expand All @@ -55,12 +34,8 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
}

const cookieHeader = request.headers.get('Cookie');

// Parse the cookie's value (returns an object or null if no cookie exists)
const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
const providerSettings: Record<string, IProviderSetting> = JSON.parse(
parseCookies(cookieHeader || '').providers || '{}',
);
const apiKeys = getApiKeysFromCookie(cookieHeader);
const providerSettings = getProviderSettingsFromCookie(cookieHeader);

try {
const result = await streamText({
Expand Down
44 changes: 15 additions & 29 deletions app/routes/api.llmcall.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,24 @@
import { type ActionFunctionArgs } from '@remix-run/cloudflare';

//import { StreamingTextResponse, parseStreamPart } from 'ai';
import { streamText } from '~/lib/.server/llm/stream-text';
import type { IProviderSetting, ProviderInfo } from '~/types/model';
import { generateText } from 'ai';
import { getModelList, PROVIDER_LIST } from '~/utils/constants';
import { PROVIDER_LIST } from '~/utils/constants';
import { MAX_TOKENS } from '~/lib/.server/llm/constants';
import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';
import { getApiKeysFromCookie, getProviderSettingsFromCookie } from '~/lib/api/cookies';

export async function action(args: ActionFunctionArgs) {
return llmCallAction(args);
}

function parseCookies(cookieHeader: string) {
const cookies: any = {};

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(';').map((cookie) => cookie.trim());

items.forEach((item) => {
const [name, ...rest] = item.split('=');

if (name && rest) {
// Decode the name and value, and join value parts in case it contains '='
const decodedName = decodeURIComponent(name.trim());
const decodedValue = decodeURIComponent(rest.join('=').trim());
cookies[decodedName] = decodedValue;
}
});

return cookies;
async function getModelList(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}) {
const llmManager = LLMManager.getInstance(import.meta.env);
return llmManager.updateModelList(options);
}

async function llmCallAction({ context, request }: ActionFunctionArgs) {
Expand Down Expand Up @@ -58,12 +48,8 @@ async function llmCallAction({ context, request }: ActionFunctionArgs) {
}

const cookieHeader = request.headers.get('Cookie');

// Parse the cookie's value (returns an object or null if no cookie exists)
const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
const providerSettings: Record<string, IProviderSetting> = JSON.parse(
parseCookies(cookieHeader || '').providers || '{}',
);
const apiKeys = getApiKeysFromCookie(cookieHeader);
const providerSettings = getProviderSettingsFromCookie(cookieHeader);

if (streamOutput) {
try {
Expand Down Expand Up @@ -105,8 +91,8 @@ async function llmCallAction({ context, request }: ActionFunctionArgs) {
}
} else {
try {
const MODEL_LIST = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any });
const modelDetails = MODEL_LIST.find((m) => m.name === model);
const models = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any });
const modelDetails = models.find((m: ModelInfo) => m.name === model);

if (!modelDetails) {
throw new Error('Model not found');
Expand Down
2 changes: 2 additions & 0 deletions app/routes/api.models.$provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import { loader } from './api.models';
export { loader };
Loading
Loading