From 389eedcac41dbe2d82e2736a6d7b1b7e5f01fc8f Mon Sep 17 00:00:00 2001 From: Anirban Kar Date: Tue, 31 Dec 2024 19:22:46 +0530 Subject: [PATCH] fix: better model loading ui feedback and model list update (#954) * fix: better model loading feedback and model list update * added load on providersettings update --- app/components/chat/BaseChat.tsx | 113 ++++++++++++++++---------- app/components/chat/ModelSelector.tsx | 23 ++++-- app/lib/modules/llm/manager.ts | 7 ++ public/icons/Hyperbolic.svg | 3 + 4 files changed, 97 insertions(+), 49 deletions(-) create mode 100644 public/icons/Hyperbolic.svg diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index 1777d734..f285366c 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -3,7 +3,7 @@ * Preventing TS checks with files presented in the video for a better presentation. */ import type { Message } from 'ai'; -import React, { type RefCallback, useEffect, useState } from 'react'; +import React, { type RefCallback, useCallback, useEffect, useState } from 'react'; import { ClientOnly } from 'remix-utils/client-only'; import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; @@ -31,6 +31,7 @@ import { toast } from 'react-toastify'; import StarterTemplates from './StarterTemplates'; import type { ActionAlert } from '~/types/actions'; import ChatAlert from './ChatAlert'; +import { LLMManager } from '~/lib/modules/llm/manager'; const TEXTAREA_MIN_HEIGHT = 76; @@ -100,26 +101,36 @@ export const BaseChat = React.forwardRef( ref, ) => { const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; - const [apiKeys, setApiKeys] = useState>(() => { - const savedKeys = Cookies.get('apiKeys'); - - if (savedKeys) { - try { - return JSON.parse(savedKeys); - } catch (error) { - console.error('Failed to parse API keys from cookies:', error); - return {}; - } - } - - return {}; - }); + const [apiKeys, setApiKeys] = useState>(getApiKeysFromCookies()); const [modelList, setModelList] = useState(MODEL_LIST); const [isModelSettingsCollapsed, setIsModelSettingsCollapsed] = useState(false); const [isListening, setIsListening] = useState(false); const [recognition, setRecognition] = useState(null); const [transcript, setTranscript] = useState(''); + const [isModelLoading, setIsModelLoading] = useState('all'); + const getProviderSettings = useCallback(() => { + let providerSettings: Record | undefined = undefined; + + try { + const savedProviderSettings = Cookies.get('providers'); + + if (savedProviderSettings) { + const parsedProviderSettings = JSON.parse(savedProviderSettings); + + if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) { + providerSettings = parsedProviderSettings; + } + } + } catch (error) { + console.error('Error loading Provider Settings from cookies:', error); + + // Clear invalid cookie data + Cookies.remove('providers'); + } + + return providerSettings; + }, []); useEffect(() => { console.log(transcript); }, [transcript]); @@ -157,25 +168,7 @@ export const BaseChat = React.forwardRef( }, []); useEffect(() => { - let providerSettings: Record | undefined = undefined; - - try { - const savedProviderSettings = Cookies.get('providers'); - - if (savedProviderSettings) { - const parsedProviderSettings = JSON.parse(savedProviderSettings); - - if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) { - providerSettings = parsedProviderSettings; - } - } - } catch (error) { - console.error('Error loading Provider Settings from cookies:', error); - - // Clear invalid cookie data - Cookies.remove('providers'); - } - + const providerSettings = getProviderSettings(); let parsedApiKeys: Record | undefined = {}; try { @@ -187,12 +180,49 @@ export const BaseChat = React.forwardRef( // Clear invalid cookie data Cookies.remove('apiKeys'); } + setIsModelLoading('all'); + initializeModelList({ apiKeys: parsedApiKeys, providerSettings }) + .then((modelList) => { + console.log('Model List: ', modelList); + setModelList(modelList); + }) + .catch((error) => { + console.error('Error initializing model list:', error); + }) + .finally(() => { + setIsModelLoading(undefined); + }); + }, [providerList]); - initializeModelList({ apiKeys: parsedApiKeys, providerSettings }).then((modelList) => { - console.log('Model List: ', modelList); - setModelList(modelList); - }); - }, [apiKeys]); + const onApiKeysChange = async (providerName: string, apiKey: string) => { + const newApiKeys = { ...apiKeys, [providerName]: apiKey }; + setApiKeys(newApiKeys); + Cookies.set('apiKeys', JSON.stringify(newApiKeys)); + + const provider = LLMManager.getInstance(import.meta.env || process.env || {}).getProvider(providerName); + + if (provider && provider.getDynamicModels) { + setIsModelLoading(providerName); + + try { + const providerSettings = getProviderSettings(); + const staticModels = provider.staticModels; + const dynamicModels = await provider.getDynamicModels( + newApiKeys, + providerSettings, + import.meta.env || process.env || {}, + ); + + setModelList((preModels) => { + const filteredOutPreModels = preModels.filter((x) => x.provider !== providerName); + return [...filteredOutPreModels, ...staticModels, ...dynamicModels]; + }); + } catch (error) { + console.error('Error loading dynamic models:', error); + } + setIsModelLoading(undefined); + } + }; const startListening = () => { if (recognition) { @@ -381,15 +411,14 @@ export const BaseChat = React.forwardRef( setProvider={setProvider} providerList={providerList || (PROVIDER_LIST as ProviderInfo[])} apiKeys={apiKeys} + modelLoading={isModelLoading} /> {(providerList || []).length > 0 && provider && ( { - const newApiKeys = { ...apiKeys, [provider.name]: key }; - setApiKeys(newApiKeys); - Cookies.set('apiKeys', JSON.stringify(newApiKeys)); + onApiKeysChange(provider.name, key); }} /> )} diff --git a/app/components/chat/ModelSelector.tsx b/app/components/chat/ModelSelector.tsx index ec4da63f..521ccac3 100644 --- a/app/components/chat/ModelSelector.tsx +++ b/app/components/chat/ModelSelector.tsx @@ -10,6 +10,7 @@ interface ModelSelectorProps { modelList: ModelInfo[]; providerList: ProviderInfo[]; apiKeys: Record; + modelLoading?: string; } export const ModelSelector = ({ @@ -19,6 +20,7 @@ export const ModelSelector = ({ setProvider, modelList, providerList, + modelLoading, }: ModelSelectorProps) => { // Load enabled providers from cookies @@ -83,14 +85,21 @@ export const ModelSelector = ({ value={model} onChange={(e) => setModel?.(e.target.value)} className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all lg:max-w-[70%]" + disabled={modelLoading === 'all' || modelLoading === provider?.name} > - {[...modelList] - .filter((e) => e.provider == provider?.name && e.name) - .map((modelOption, index) => ( - - ))} + {modelLoading == 'all' || modelLoading == provider?.name ? ( + + ) : ( + [...modelList] + .filter((e) => e.provider == provider?.name && e.name) + .map((modelOption, index) => ( + + )) + )} ); diff --git a/app/lib/modules/llm/manager.ts b/app/lib/modules/llm/manager.ts index 38dc8254..8f46c84f 100644 --- a/app/lib/modules/llm/manager.ts +++ b/app/lib/modules/llm/manager.ts @@ -79,9 +79,16 @@ export class LLMManager { }): Promise { const { apiKeys, providerSettings, serverEnv } = options; + let enabledProviders = Array.from(this._providers.values()).map((p) => p.name); + + if (providerSettings) { + enabledProviders = enabledProviders.filter((p) => providerSettings[p].enabled); + } + // Get dynamic models from all providers that support them const dynamicModels = await Promise.all( Array.from(this._providers.values()) + .filter((provider) => enabledProviders.includes(provider.name)) .filter( (provider): provider is BaseProvider & Required> => !!provider.getDynamicModels, diff --git a/public/icons/Hyperbolic.svg b/public/icons/Hyperbolic.svg new file mode 100644 index 00000000..392ed08c --- /dev/null +++ b/public/icons/Hyperbolic.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file