Merge pull request #188 from TommyHolmberg/respect-provider-choice

fix: respect provider choice from UI
Cole Medin, 2024-11-09 07:45:06 -06:00, committed by GitHub
commit 936a9c0f69
5 changed files with 61 additions and 45 deletions

View File

@@ -26,7 +26,7 @@ const EXAMPLE_PROMPTS = [
 const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]

-const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider }) => {
+const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => {
   return (
     <div className="mb-2 flex gap-2">
       <select
@@ -80,6 +80,8 @@ interface BaseChatProps {
   input?: string;
   model: string;
   setModel: (model: string) => void;
+  provider: string;
+  setProvider: (provider: string) => void;
   handleStop?: () => void;
   sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
   handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
@@ -101,6 +103,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
       input = '',
       model,
       setModel,
+      provider,
+      setProvider,
       sendMessage,
       handleInputChange,
       enhancePrompt,
@@ -193,6 +197,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
                   model={model}
                   setModel={setModel}
                   modelList={MODEL_LIST}
+                  provider={provider}
+                  setProvider={setProvider}
                   providerList={providerList}
                   provider={provider}
                   setProvider={setProvider}

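The BaseChat changes above only thread provider and setProvider through as props; the body of ModelSelector itself is not shown in this diff. As a rough sketch of what the receiving end of those props typically looks like (a hypothetical component, not code from this PR), a controlled provider select would report choices back via setProvider:

    // Hypothetical sketch: a controlled provider dropdown driven by the
    // provider/setProvider props added above. Not part of this commit's diff.
    const ProviderSelect = ({
      provider,
      setProvider,
      providerList,
    }: {
      provider: string;
      setProvider: (provider: string) => void;
      providerList: string[];
    }) => (
      <select value={provider} onChange={(e) => setProvider(e.target.value)}>
        {providerList.map((p) => (
          <option key={p} value={p}>
            {p}
          </option>
        ))}
      </select>
    );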
View File

@@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
 import { chatStore } from '~/lib/stores/chat';
 import { workbenchStore } from '~/lib/stores/workbench';
 import { fileModificationsToHTML } from '~/utils/diff';
-import { DEFAULT_MODEL } from '~/utils/constants';
+import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
 import { cubicEasingFn } from '~/utils/easings';
 import { createScopedLogger, renderLogger } from '~/utils/logger';
 import { BaseChat } from './BaseChat';
@@ -75,6 +75,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => {
   const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
   const [model, setModel] = useState(DEFAULT_MODEL);
+  const [provider, setProvider] = useState(DEFAULT_PROVIDER);

   const { showChat } = useStore(chatStore);
@@ -188,7 +189,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => {
      * manually reset the input and we'd have to manually pass in file attachments. However, those
      * aren't relevant here.
      */
-    append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
+    append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${diff}\n\n${_input}` });

     /**
      * After sending a new message we reset all modifications since the model
@@ -196,7 +197,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => {
      */
     workbenchStore.resetAllFileModifications();
   } else {
-    append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
+    append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${_input}` });
   }

   setInput('');
@@ -228,6 +229,8 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => {
           sendMessage={sendMessage}
           model={model}
           setModel={setModel}
+          provider={provider}
+          setProvider={setProvider}
           messageRef={messageRef}
           scrollRef={scrollRef}
           handleInputChange={handleInputChange}

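With the two append calls above, every outgoing user message now carries both tags in a fixed order. For the defaults in ~/utils/constants, the content string takes this shape (the prompt text is a made-up example):

    // Shape of the content built by the template literals above:
    const content =
      '[Model: claude-3-5-sonnet-20240620]\n\n' +
      '[Provider: Anthropic]\n\n' +
      'Build a todo app'; // the raw _input typed by the user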
View File

@@ -1,7 +1,7 @@
 // @ts-nocheck
 // Preventing TS checks with files presented in the video for a better presentation.
 import { modificationsRegex } from '~/utils/diff';
-import { MODEL_REGEX } from '~/utils/constants';
+import { MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
 import { Markdown } from './Markdown';

 interface UserMessageProps {
@@ -17,5 +17,5 @@ export function UserMessage({ content }: UserMessageProps) {
 }

 function sanitizeUserMessage(content: string) {
-  return content.replace(modificationsRegex, '').replace(MODEL_REGEX, '').trim();
+  return content.replace(modificationsRegex, '').replace(MODEL_REGEX, 'Using: $1').replace(PROVIDER_REGEX, ' ($1)\n\n').trim();
 }

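A worked example of the new sanitizeUserMessage, using MODEL_REGEX and PROVIDER_REGEX from ~/utils/constants (the model name is a made-up example, and the message has no file modifications to strip):

    const raw = '[Model: gpt-4o]\n\n[Provider: OpenAI]\n\nBuild a todo app';
    const shown = raw
      .replace(MODEL_REGEX, 'Using: $1') // -> 'Using: gpt-4o[Provider: OpenAI]\n\nBuild a todo app'
      .replace(PROVIDER_REGEX, ' ($1)\n\n') // -> 'Using: gpt-4o (OpenAI)\n\nBuild a todo app'
      .trim();
    // The rendered message reads "Using: gpt-4o (OpenAI)" followed by the prompt.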
View File

@@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
 import { getModel } from '~/lib/.server/llm/model';
 import { MAX_TOKENS } from './constants';
 import { getSystemPrompt } from './prompts';
-import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
+import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';

 interface ToolResult<Name extends string, Args, Result> {
   toolCallId: string;
@@ -24,18 +24,22 @@ export type Messages = Message[];

 export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

-function extractModelFromMessage(message: Message): { model: string; content: string } {
-  const modelRegex = /^\[Model: (.*?)\]\n\n/;
-  const match = message.content.match(modelRegex);
-
-  if (match) {
-    const model = match[1];
-    const content = message.content.replace(modelRegex, '');
-    return { model, content };
-  }
-
-  // Default model if not specified
-  return { model: DEFAULT_MODEL, content: message.content };
+function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
+  // Extract model
+  const modelMatch = message.content.match(MODEL_REGEX);
+  const model = modelMatch ? modelMatch[1] : DEFAULT_MODEL;
+
+  // Extract provider
+  const providerMatch = message.content.match(PROVIDER_REGEX);
+  const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
+
+  // Remove model and provider lines from content
+  const cleanedContent = message.content
+    .replace(MODEL_REGEX, '')
+    .replace(PROVIDER_REGEX, '')
+    .trim();
+
+  return { model, provider, content: cleanedContent };
 }

 export function streamText(
@@ -45,26 +49,28 @@ export function streamText(
   apiKeys?: Record<string, string>
 ) {
   let currentModel = DEFAULT_MODEL;
+  let currentProvider = DEFAULT_PROVIDER;
+
   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
-      const { model, content } = extractModelFromMessage(message);
-      if (model && MODEL_LIST.find((m) => m.name === model)) {
-        currentModel = model; // Update the current model
+      const { model, provider, content } = extractPropertiesFromMessage(message);
+
+      if (MODEL_LIST.find((m) => m.name === model)) {
+        currentModel = model;
       }
+
+      currentProvider = provider;
+
       return { ...message, content };
     }
-    return message;
+
+    return message; // No changes for non-user messages
   });

-  const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;
-
   return _streamText({
-    model: getModel(provider, currentModel, env, apiKeys),
+    model: getModel(currentProvider, currentModel, env, apiKeys),
     system: getSystemPrompt(),
     maxTokens: MAX_TOKENS,
-    // headers: {
-    //   'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
-    // },
     messages: convertToCoreMessages(processedMessages),
     ...options,
   });

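Functionally, the key shift in streamText is that the provider now comes straight from the [Provider: ...] tag instead of being re-derived from MODEL_LIST, which is what makes the UI choice stick. A quick round-trip check of extractPropertiesFromMessage as defined above (the tag values are illustrative):

    const message = {
      role: 'user' as const,
      content: '[Model: llama-3.1-70b-versatile]\n\n[Provider: Groq]\n\nRefactor this file',
    };

    const { model, provider, content } = extractPropertiesFromMessage(message);
    // model    === 'llama-3.1-70b-versatile'
    // provider === 'Groq'
    // content  === 'Refactor this file'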
View File

@@ -4,6 +4,7 @@ export const WORK_DIR_NAME = 'project';
 export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
 export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
 export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/;
+export const PROVIDER_REGEX = /\[Provider: (.*?)\]\n\n/;
 export const DEFAULT_MODEL = 'claude-3-5-sonnet-20240620';
 export const DEFAULT_PROVIDER = 'Anthropic';
@@ -20,7 +21,7 @@ const staticModels: ModelInfo[] = [
   { name: 'qwen/qwen-110b-chat', label: 'OpenRouter Qwen 110b Chat (OpenRouter)', provider: 'OpenRouter' },
   { name: 'cohere/command', label: 'Cohere Command (OpenRouter)', provider: 'OpenRouter' },
   { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' },
-  { name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google'},
+  { name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google' },
   { name: 'llama-3.1-70b-versatile', label: 'Llama 3.1 70b (Groq)', provider: 'Groq' },
   { name: 'llama-3.1-8b-instant', label: 'Llama 3.1 8b (Groq)', provider: 'Groq' },
   { name: 'llama-3.2-11b-vision-preview', label: 'Llama 3.2 11b (Groq)', provider: 'Groq' },
@@ -82,32 +83,32 @@ async function getOllamaModels(): Promise<ModelInfo[]> {
 }

 async function getOpenAILikeModels(): Promise<ModelInfo[]> {
   try {
-    const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
+    const base_url = import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
     if (!base_url) {
       return [];
     }
     const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? "";
     const response = await fetch(`${base_url}/models`, {
       headers: {
         Authorization: `Bearer ${api_key}`,
       }
     });
     const res = await response.json() as any;
     return res.data.map((model: any) => ({
       name: model.id,
       label: model.id,
       provider: 'OpenAILike',
     }));
-  }catch (e) {
+  } catch (e) {
     return []
   }
 }

 async function initializeModelList(): Promise<void> {
   const ollamaModels = await getOllamaModels();
   const openAiLikeModels = await getOpenAILikeModels();
-  MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels];
+  MODEL_LIST = [...ollamaModels, ...openAiLikeModels, ...staticModels];
 }
 initializeModelList().then();

 export { getOllamaModels, getOpenAILikeModels, initializeModelList };
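
Note the asymmetry between the two patterns: MODEL_REGEX is anchored with ^ because the model tag always opens the message, while the new PROVIDER_REGEX is unanchored because the provider tag follows the model tag. A small check of both against a tagged message:

    const tagged = '[Model: gemini-1.5-pro-latest]\n\n[Provider: Google]\n\nHello';
    MODEL_REGEX.exec(tagged)?.[1]; // 'gemini-1.5-pro-latest'
    PROVIDER_REGEX.exec(tagged)?.[1]; // 'Google' (matches even though the tag is not at the start)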