mirror of
https://github.com/stackblitz/bolt.new
synced 2025-03-12 14:58:30 +00:00
Merge pull request #188 from TommyHolmberg/respect-provider-choice
fix: respect provider choice from UI
This commit is contained in:
commit
936a9c0f69
@ -26,7 +26,7 @@ const EXAMPLE_PROMPTS = [
|
|||||||
|
|
||||||
const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]
|
const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]
|
||||||
|
|
||||||
const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider }) => {
|
const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => {
|
||||||
return (
|
return (
|
||||||
<div className="mb-2 flex gap-2">
|
<div className="mb-2 flex gap-2">
|
||||||
<select
|
<select
|
||||||
@ -80,6 +80,8 @@ interface BaseChatProps {
|
|||||||
input?: string;
|
input?: string;
|
||||||
model: string;
|
model: string;
|
||||||
setModel: (model: string) => void;
|
setModel: (model: string) => void;
|
||||||
|
provider: string;
|
||||||
|
setProvider: (provider: string) => void;
|
||||||
handleStop?: () => void;
|
handleStop?: () => void;
|
||||||
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
|
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
|
||||||
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
|
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
|
||||||
@ -101,6 +103,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
|
|||||||
input = '',
|
input = '',
|
||||||
model,
|
model,
|
||||||
setModel,
|
setModel,
|
||||||
|
provider,
|
||||||
|
setProvider,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
handleInputChange,
|
handleInputChange,
|
||||||
enhancePrompt,
|
enhancePrompt,
|
||||||
@ -193,6 +197,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
|
|||||||
model={model}
|
model={model}
|
||||||
setModel={setModel}
|
setModel={setModel}
|
||||||
modelList={MODEL_LIST}
|
modelList={MODEL_LIST}
|
||||||
|
provider={provider}
|
||||||
|
setProvider={setProvider}
|
||||||
providerList={providerList}
|
providerList={providerList}
|
||||||
provider={provider}
|
provider={provider}
|
||||||
setProvider={setProvider}
|
setProvider={setProvider}
|
||||||
|
@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
|
|||||||
import { chatStore } from '~/lib/stores/chat';
|
import { chatStore } from '~/lib/stores/chat';
|
||||||
import { workbenchStore } from '~/lib/stores/workbench';
|
import { workbenchStore } from '~/lib/stores/workbench';
|
||||||
import { fileModificationsToHTML } from '~/utils/diff';
|
import { fileModificationsToHTML } from '~/utils/diff';
|
||||||
import { DEFAULT_MODEL } from '~/utils/constants';
|
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
|
||||||
import { cubicEasingFn } from '~/utils/easings';
|
import { cubicEasingFn } from '~/utils/easings';
|
||||||
import { createScopedLogger, renderLogger } from '~/utils/logger';
|
import { createScopedLogger, renderLogger } from '~/utils/logger';
|
||||||
import { BaseChat } from './BaseChat';
|
import { BaseChat } from './BaseChat';
|
||||||
@ -75,6 +75,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|||||||
|
|
||||||
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
|
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
|
||||||
const [model, setModel] = useState(DEFAULT_MODEL);
|
const [model, setModel] = useState(DEFAULT_MODEL);
|
||||||
|
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
|
||||||
|
|
||||||
const { showChat } = useStore(chatStore);
|
const { showChat } = useStore(chatStore);
|
||||||
|
|
||||||
@ -188,7 +189,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|||||||
* manually reset the input and we'd have to manually pass in file attachments. However, those
|
* manually reset the input and we'd have to manually pass in file attachments. However, those
|
||||||
* aren't relevant here.
|
* aren't relevant here.
|
||||||
*/
|
*/
|
||||||
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
|
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${diff}\n\n${_input}` });
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* After sending a new message we reset all modifications since the model
|
* After sending a new message we reset all modifications since the model
|
||||||
@ -196,7 +197,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|||||||
*/
|
*/
|
||||||
workbenchStore.resetAllFileModifications();
|
workbenchStore.resetAllFileModifications();
|
||||||
} else {
|
} else {
|
||||||
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
|
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${_input}` });
|
||||||
}
|
}
|
||||||
|
|
||||||
setInput('');
|
setInput('');
|
||||||
@ -228,6 +229,8 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
|
|||||||
sendMessage={sendMessage}
|
sendMessage={sendMessage}
|
||||||
model={model}
|
model={model}
|
||||||
setModel={setModel}
|
setModel={setModel}
|
||||||
|
provider={provider}
|
||||||
|
setProvider={setProvider}
|
||||||
messageRef={messageRef}
|
messageRef={messageRef}
|
||||||
scrollRef={scrollRef}
|
scrollRef={scrollRef}
|
||||||
handleInputChange={handleInputChange}
|
handleInputChange={handleInputChange}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// @ts-nocheck
|
// @ts-nocheck
|
||||||
// Preventing TS checks with files presented in the video for a better presentation.
|
// Preventing TS checks with files presented in the video for a better presentation.
|
||||||
import { modificationsRegex } from '~/utils/diff';
|
import { modificationsRegex } from '~/utils/diff';
|
||||||
import { MODEL_REGEX } from '~/utils/constants';
|
import { MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
|
||||||
import { Markdown } from './Markdown';
|
import { Markdown } from './Markdown';
|
||||||
|
|
||||||
interface UserMessageProps {
|
interface UserMessageProps {
|
||||||
@ -17,5 +17,5 @@ export function UserMessage({ content }: UserMessageProps) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function sanitizeUserMessage(content: string) {
|
function sanitizeUserMessage(content: string) {
|
||||||
return content.replace(modificationsRegex, '').replace(MODEL_REGEX, '').trim();
|
return content.replace(modificationsRegex, '').replace(MODEL_REGEX, 'Using: $1').replace(PROVIDER_REGEX, ' ($1)\n\n').trim();
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
|
|||||||
import { getModel } from '~/lib/.server/llm/model';
|
import { getModel } from '~/lib/.server/llm/model';
|
||||||
import { MAX_TOKENS } from './constants';
|
import { MAX_TOKENS } from './constants';
|
||||||
import { getSystemPrompt } from './prompts';
|
import { getSystemPrompt } from './prompts';
|
||||||
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
|
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
|
||||||
|
|
||||||
interface ToolResult<Name extends string, Args, Result> {
|
interface ToolResult<Name extends string, Args, Result> {
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
@ -24,18 +24,22 @@ export type Messages = Message[];
|
|||||||
|
|
||||||
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
|
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
|
||||||
|
|
||||||
function extractModelFromMessage(message: Message): { model: string; content: string } {
|
function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
|
||||||
const modelRegex = /^\[Model: (.*?)\]\n\n/;
|
// Extract model
|
||||||
const match = message.content.match(modelRegex);
|
const modelMatch = message.content.match(MODEL_REGEX);
|
||||||
|
const model = modelMatch ? modelMatch[1] : DEFAULT_MODEL;
|
||||||
|
|
||||||
if (match) {
|
// Extract provider
|
||||||
const model = match[1];
|
const providerMatch = message.content.match(PROVIDER_REGEX);
|
||||||
const content = message.content.replace(modelRegex, '');
|
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
|
||||||
return { model, content };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default model if not specified
|
// Remove model and provider lines from content
|
||||||
return { model: DEFAULT_MODEL, content: message.content };
|
const cleanedContent = message.content
|
||||||
|
.replace(MODEL_REGEX, '')
|
||||||
|
.replace(PROVIDER_REGEX, '')
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
return { model, provider, content: cleanedContent };
|
||||||
}
|
}
|
||||||
|
|
||||||
export function streamText(
|
export function streamText(
|
||||||
@ -45,26 +49,28 @@ export function streamText(
|
|||||||
apiKeys?: Record<string, string>
|
apiKeys?: Record<string, string>
|
||||||
) {
|
) {
|
||||||
let currentModel = DEFAULT_MODEL;
|
let currentModel = DEFAULT_MODEL;
|
||||||
|
let currentProvider = DEFAULT_PROVIDER;
|
||||||
|
|
||||||
const processedMessages = messages.map((message) => {
|
const processedMessages = messages.map((message) => {
|
||||||
if (message.role === 'user') {
|
if (message.role === 'user') {
|
||||||
const { model, content } = extractModelFromMessage(message);
|
const { model, provider, content } = extractPropertiesFromMessage(message);
|
||||||
if (model && MODEL_LIST.find((m) => m.name === model)) {
|
|
||||||
currentModel = model; // Update the current model
|
if (MODEL_LIST.find((m) => m.name === model)) {
|
||||||
|
currentModel = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
currentProvider = provider;
|
||||||
|
|
||||||
return { ...message, content };
|
return { ...message, content };
|
||||||
}
|
}
|
||||||
return message;
|
|
||||||
|
return message; // No changes for non-user messages
|
||||||
});
|
});
|
||||||
|
|
||||||
const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;
|
|
||||||
|
|
||||||
return _streamText({
|
return _streamText({
|
||||||
model: getModel(provider, currentModel, env, apiKeys),
|
model: getModel(currentProvider, currentModel, env, apiKeys),
|
||||||
system: getSystemPrompt(),
|
system: getSystemPrompt(),
|
||||||
maxTokens: MAX_TOKENS,
|
maxTokens: MAX_TOKENS,
|
||||||
// headers: {
|
|
||||||
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
|
|
||||||
// },
|
|
||||||
messages: convertToCoreMessages(processedMessages),
|
messages: convertToCoreMessages(processedMessages),
|
||||||
...options,
|
...options,
|
||||||
});
|
});
|
||||||
|
@ -4,6 +4,7 @@ export const WORK_DIR_NAME = 'project';
|
|||||||
export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
|
export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
|
||||||
export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
|
export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
|
||||||
export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/;
|
export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/;
|
||||||
|
export const PROVIDER_REGEX = /\[Provider: (.*?)\]\n\n/;
|
||||||
export const DEFAULT_MODEL = 'claude-3-5-sonnet-20240620';
|
export const DEFAULT_MODEL = 'claude-3-5-sonnet-20240620';
|
||||||
export const DEFAULT_PROVIDER = 'Anthropic';
|
export const DEFAULT_PROVIDER = 'Anthropic';
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ const staticModels: ModelInfo[] = [
|
|||||||
{ name: 'qwen/qwen-110b-chat', label: 'OpenRouter Qwen 110b Chat (OpenRouter)', provider: 'OpenRouter' },
|
{ name: 'qwen/qwen-110b-chat', label: 'OpenRouter Qwen 110b Chat (OpenRouter)', provider: 'OpenRouter' },
|
||||||
{ name: 'cohere/command', label: 'Cohere Command (OpenRouter)', provider: 'OpenRouter' },
|
{ name: 'cohere/command', label: 'Cohere Command (OpenRouter)', provider: 'OpenRouter' },
|
||||||
{ name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' },
|
{ name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' },
|
||||||
{ name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google'},
|
{ name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google' },
|
||||||
{ name: 'llama-3.1-70b-versatile', label: 'Llama 3.1 70b (Groq)', provider: 'Groq' },
|
{ name: 'llama-3.1-70b-versatile', label: 'Llama 3.1 70b (Groq)', provider: 'Groq' },
|
||||||
{ name: 'llama-3.1-8b-instant', label: 'Llama 3.1 8b (Groq)', provider: 'Groq' },
|
{ name: 'llama-3.1-8b-instant', label: 'Llama 3.1 8b (Groq)', provider: 'Groq' },
|
||||||
{ name: 'llama-3.2-11b-vision-preview', label: 'Llama 3.2 11b (Groq)', provider: 'Groq' },
|
{ name: 'llama-3.2-11b-vision-preview', label: 'Llama 3.2 11b (Groq)', provider: 'Groq' },
|
||||||
@ -82,32 +83,32 @@ async function getOllamaModels(): Promise<ModelInfo[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function getOpenAILikeModels(): Promise<ModelInfo[]> {
|
async function getOpenAILikeModels(): Promise<ModelInfo[]> {
|
||||||
try {
|
try {
|
||||||
const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
|
const base_url = import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
|
||||||
if (!base_url) {
|
if (!base_url) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? "";
|
const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? "";
|
||||||
const response = await fetch(`${base_url}/models`, {
|
const response = await fetch(`${base_url}/models`, {
|
||||||
headers: {
|
headers: {
|
||||||
Authorization: `Bearer ${api_key}`,
|
Authorization: `Bearer ${api_key}`,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
const res = await response.json() as any;
|
const res = await response.json() as any;
|
||||||
return res.data.map((model: any) => ({
|
return res.data.map((model: any) => ({
|
||||||
name: model.id,
|
name: model.id,
|
||||||
label: model.id,
|
label: model.id,
|
||||||
provider: 'OpenAILike',
|
provider: 'OpenAILike',
|
||||||
}));
|
}));
|
||||||
}catch (e) {
|
} catch (e) {
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
async function initializeModelList(): Promise<void> {
|
async function initializeModelList(): Promise<void> {
|
||||||
const ollamaModels = await getOllamaModels();
|
const ollamaModels = await getOllamaModels();
|
||||||
const openAiLikeModels = await getOpenAILikeModels();
|
const openAiLikeModels = await getOpenAILikeModels();
|
||||||
MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels];
|
MODEL_LIST = [...ollamaModels, ...openAiLikeModels, ...staticModels];
|
||||||
}
|
}
|
||||||
initializeModelList().then();
|
initializeModelList().then();
|
||||||
export { getOllamaModels, getOpenAILikeModels, initializeModelList };
|
export { getOllamaModels, getOpenAILikeModels, initializeModelList };
|
Loading…
Reference in New Issue
Block a user