Various bug fixes around model/provider selection

This commit is contained in:
eduardruzga 2024-11-13 22:22:25 +02:00
parent e55fb57138
commit 9396734dea
2 changed files with 18 additions and 9 deletions

View File

@ -7,7 +7,7 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton'; import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client'; import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames'; import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER, PROVIDER_LIST, ProviderInfo } from '~/utils/constants'; import { MODEL_LIST, DEFAULT_PROVIDER, PROVIDER_LIST, ProviderInfo, initializeModelList } from '~/utils/constants';
import { Messages } from './Messages.client'; import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client'; import { SendButton } from './SendButton.client';
import { useState } from 'react'; import { useState } from 'react';
@ -45,8 +45,10 @@ const ModelSelector = ({ model, setModel, provider, setProvider, modelList, prov
))} ))}
</select> </select>
<select <select
key={provider?.name}
value={model} value={model}
onChange={(e) => setModel(e.target.value)} onChange={(e) => setModel(e.target.value)}
style={{maxWidth: "70%"}}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all" className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
> >
{[...modelList] {[...modelList]
@ -111,6 +113,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
console.log(provider); console.log(provider);
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [apiKeys, setApiKeys] = useState<Record<string, string>>({}); const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
const [modelList, setModelList] = useState(MODEL_LIST);
useEffect(() => { useEffect(() => {
// Load API keys from cookies on component mount // Load API keys from cookies on component mount
@ -127,6 +131,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
// Clear invalid cookie data // Clear invalid cookie data
Cookies.remove('apiKeys'); Cookies.remove('apiKeys');
} }
initializeModelList().then(modelList => {
setModelList(modelList);
});
}, []); }, []);
const updateApiKey = (provider: string, key: string) => { const updateApiKey = (provider: string, key: string) => {
@ -190,12 +198,13 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
})} })}
> >
<ModelSelector <ModelSelector
key={provider?.name + ':' + modelList.length}
model={model} model={model}
setModel={setModel} setModel={setModel}
modelList={MODEL_LIST} modelList={modelList}
provider={provider} provider={provider}
setProvider={setProvider} setProvider={setProvider}
providerList={providerList} providerList={PROVIDER_LIST}
/> />
{provider && {provider &&
<APIKeyManager <APIKeyManager

View File

@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat'; import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench'; import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff'; import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants'; import { DEFAULT_MODEL, DEFAULT_PROVIDER, PROVIDER_LIST, ProviderInfo } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings'; import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger'; import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat'; import { BaseChat } from './BaseChat';
@ -80,7 +80,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
}); });
const [provider, setProvider] = useState(() => { const [provider, setProvider] = useState(() => {
const savedProvider = Cookies.get('selectedProvider'); const savedProvider = Cookies.get('selectedProvider');
return savedProvider || DEFAULT_PROVIDER; return PROVIDER_LIST.find(p => p.name === savedProvider) || DEFAULT_PROVIDER;
}); });
const { showChat } = useStore(chatStore); const { showChat } = useStore(chatStore);
@ -96,7 +96,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
}, },
onError: (error) => { onError: (error) => {
logger.error('Request failed\n\n', error); logger.error('Request failed\n\n', error);
toast.error('There was an error processing your request'); toast.error('There was an error processing your request: ' + (error.message ? error.message : "No details were returned"));
}, },
onFinish: () => { onFinish: () => {
logger.debug('Finished streaming'); logger.debug('Finished streaming');
@ -227,9 +227,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
Cookies.set('selectedModel', newModel, { expires: 30 }); Cookies.set('selectedModel', newModel, { expires: 30 });
}; };
const handleProviderChange = (newProvider: string) => { const handleProviderChange = (newProvider: ProviderInfo) => {
setProvider(newProvider); setProvider(newProvider);
Cookies.set('selectedProvider', newProvider, { expires: 30 }); Cookies.set('selectedProvider', newProvider.name, { expires: 30 });
}; };
return ( return (