mirror of
https://github.com/stackblitz/bolt.new
synced 2025-02-06 04:48:04 +00:00
feat(Dynamic Models): together AI Dynamic Models
This commit is contained in:
parent
115dcbb3bd
commit
1589d2a8f5
@ -1,11 +1,8 @@
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-nocheck – TODO: Provider proper types
|
||||
|
||||
import { convertToCoreMessages, streamText as _streamText } from 'ai';
|
||||
import { getModel } from '~/lib/.server/llm/model';
|
||||
import { MAX_TOKENS } from './constants';
|
||||
import { getSystemPrompt } from './prompts';
|
||||
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
|
||||
import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
|
||||
|
||||
interface ToolResult<Name extends string, Args, Result> {
|
||||
toolCallId: string;
|
||||
@ -32,7 +29,7 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
|
||||
|
||||
// Extract provider
|
||||
const providerMatch = message.content.match(PROVIDER_REGEX);
|
||||
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
|
||||
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER.name;
|
||||
|
||||
// Remove model and provider lines from content
|
||||
const cleanedContent = message.content.replace(MODEL_REGEX, '').replace(PROVIDER_REGEX, '').trim();
|
||||
@ -40,10 +37,10 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
|
||||
return { model, provider, content: cleanedContent };
|
||||
}
|
||||
|
||||
export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
|
||||
export async function streamText(messages: Messages, env: Env, options?: StreamingOptions,apiKeys?: Record<string, string>) {
|
||||
let currentModel = DEFAULT_MODEL;
|
||||
let currentProvider = DEFAULT_PROVIDER;
|
||||
|
||||
let currentProvider = DEFAULT_PROVIDER.name;
|
||||
const MODEL_LIST = await getModelList(apiKeys||{});
|
||||
const processedMessages = messages.map((message) => {
|
||||
if (message.role === 'user') {
|
||||
const { model, provider, content } = extractPropertiesFromMessage(message);
|
||||
@ -51,7 +48,6 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
|
||||
if (MODEL_LIST.find((m) => m.name === model)) {
|
||||
currentModel = model;
|
||||
}
|
||||
|
||||
currentProvider = provider;
|
||||
|
||||
return { ...message, content };
|
||||
@ -65,10 +61,10 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
|
||||
const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
|
||||
|
||||
return _streamText({
|
||||
model: getModel(currentProvider, currentModel, env, apiKeys),
|
||||
model: getModel(currentProvider, currentModel, env, apiKeys) as any,
|
||||
system: getSystemPrompt(),
|
||||
maxTokens: dynamicMaxTokens,
|
||||
messages: convertToCoreMessages(processedMessages),
|
||||
messages: convertToCoreMessages(processedMessages as any),
|
||||
...options,
|
||||
});
|
||||
}
|
||||
|
@ -1,6 +1,3 @@
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-nocheck – TODO: Provider proper types
|
||||
|
||||
import { type ActionFunctionArgs } from '@remix-run/cloudflare';
|
||||
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
|
||||
import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
|
||||
@ -11,8 +8,8 @@ export async function action(args: ActionFunctionArgs) {
|
||||
return chatAction(args);
|
||||
}
|
||||
|
||||
function parseCookies(cookieHeader) {
|
||||
const cookies = {};
|
||||
function parseCookies(cookieHeader:string) {
|
||||
const cookies:any = {};
|
||||
|
||||
// Split the cookie string by semicolons and spaces
|
||||
const items = cookieHeader.split(';').map((cookie) => cookie.trim());
|
||||
@ -39,14 +36,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
|
||||
const cookieHeader = request.headers.get('Cookie');
|
||||
|
||||
// Parse the cookie's value (returns an object or null if no cookie exists)
|
||||
const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || '{}');
|
||||
const apiKeys = JSON.parse(parseCookies(cookieHeader||"").apiKeys || '{}');
|
||||
|
||||
const stream = new SwitchableStream();
|
||||
|
||||
try {
|
||||
const options: StreamingOptions = {
|
||||
toolChoice: 'none',
|
||||
apiKeys,
|
||||
onFinish: async ({ text: content, finishReason }) => {
|
||||
if (finishReason !== 'length') {
|
||||
return stream.close();
|
||||
@ -63,7 +59,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
|
||||
messages.push({ role: 'assistant', content });
|
||||
messages.push({ role: 'user', content: CONTINUE_PROMPT });
|
||||
|
||||
const result = await streamText(messages, context.cloudflare.env, options);
|
||||
const result = await streamText(messages, context.cloudflare.env, options,apiKeys);
|
||||
|
||||
return stream.switchSource(result.toAIStream());
|
||||
},
|
||||
@ -79,7 +75,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
|
||||
contentType: 'text/plain; charset=utf-8',
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
} catch (error:any) {
|
||||
console.log(error);
|
||||
|
||||
if (error.message?.includes('API key')) {
|
||||
|
@ -3,7 +3,7 @@ import type { ModelInfo } from '~/utils/types';
|
||||
export type ProviderInfo = {
|
||||
staticModels: ModelInfo[];
|
||||
name: string;
|
||||
getDynamicModels?: () => Promise<ModelInfo[]>;
|
||||
getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>;
|
||||
getApiKeyLink?: string;
|
||||
labelForGetApiKey?: string;
|
||||
icon?: string;
|
||||
|
@ -1,3 +1,5 @@
|
||||
import Cookies from 'js-cookie';
|
||||
import { parseCookies } from './parseCookies';
|
||||
import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
|
||||
import type { ProviderInfo } from '~/types/model';
|
||||
|
||||
@ -262,6 +264,7 @@ const PROVIDER_LIST: ProviderInfo[] = [
|
||||
},
|
||||
{
|
||||
name: 'Together',
|
||||
getDynamicModels: getTogetherModels,
|
||||
staticModels: [
|
||||
{
|
||||
name: 'Qwen/Qwen2.5-Coder-32B-Instruct',
|
||||
@ -293,6 +296,61 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(
|
||||
|
||||
export let MODEL_LIST: ModelInfo[] = [...staticModels];
|
||||
|
||||
|
||||
export async function getModelList(apiKeys: Record<string, string>) {
|
||||
MODEL_LIST = [
|
||||
...(
|
||||
await Promise.all(
|
||||
PROVIDER_LIST.filter(
|
||||
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
|
||||
).map((p) => p.getDynamicModels(apiKeys)),
|
||||
)
|
||||
).flat(),
|
||||
...staticModels,
|
||||
];
|
||||
return MODEL_LIST;
|
||||
}
|
||||
|
||||
async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
|
||||
try {
|
||||
let baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
|
||||
let provider='Together'
|
||||
|
||||
if (!baseUrl) {
|
||||
return [];
|
||||
}
|
||||
let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? ''
|
||||
|
||||
if (apiKeys && apiKeys[provider]) {
|
||||
apiKey = apiKeys[provider];
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const response = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
});
|
||||
const res = (await response.json()) as any;
|
||||
let data: any[] = (res || []).filter((model: any) => model.type=='chat')
|
||||
return data.map((m: any) => ({
|
||||
name: m.id,
|
||||
label: `${m.display_name} - in:$${(m.pricing.input).toFixed(
|
||||
2,
|
||||
)} out:$${(m.pricing.output).toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
|
||||
provider: provider,
|
||||
maxTokenAllowed: 8000,
|
||||
}));
|
||||
} catch (e) {
|
||||
console.error('Error getting OpenAILike models:', e);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const getOllamaBaseUrl = () => {
|
||||
const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
|
||||
|
||||
@ -339,8 +397,13 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
|
||||
if (!baseUrl) {
|
||||
return [];
|
||||
}
|
||||
let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
|
||||
|
||||
let apikeys = JSON.parse(Cookies.get('apiKeys')||'{}')
|
||||
if (apikeys && apikeys['OpenAILike']){
|
||||
apiKey = apikeys['OpenAILike'];
|
||||
}
|
||||
|
||||
const apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
|
||||
const response = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
@ -396,7 +459,6 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
|
||||
if (typeof window === 'undefined') {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const baseUrl = import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
|
||||
const response = await fetch(`${baseUrl}/v1/models`);
|
||||
@ -414,12 +476,27 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
|
||||
}
|
||||
|
||||
async function initializeModelList(): Promise<ModelInfo[]> {
|
||||
let apiKeys: Record<string, string> = {};
|
||||
try {
|
||||
const storedApiKeys = Cookies.get('apiKeys');
|
||||
|
||||
if (storedApiKeys) {
|
||||
const parsedKeys = JSON.parse(storedApiKeys);
|
||||
|
||||
if (typeof parsedKeys === 'object' && parsedKeys !== null) {
|
||||
apiKeys = parsedKeys;
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
|
||||
}
|
||||
MODEL_LIST = [
|
||||
...(
|
||||
await Promise.all(
|
||||
PROVIDER_LIST.filter(
|
||||
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
|
||||
).map((p) => p.getDynamicModels()),
|
||||
).map((p) => p.getDynamicModels(apiKeys)),
|
||||
)
|
||||
).flat(),
|
||||
...staticModels,
|
||||
|
19
app/utils/parseCookies.ts
Normal file
19
app/utils/parseCookies.ts
Normal file
@ -0,0 +1,19 @@
|
||||
export function parseCookies(cookieHeader: string) {
|
||||
const cookies: any = {};
|
||||
|
||||
// Split the cookie string by semicolons and spaces
|
||||
const items = cookieHeader.split(';').map((cookie) => cookie.trim());
|
||||
|
||||
items.forEach((item) => {
|
||||
const [name, ...rest] = item.split('=');
|
||||
|
||||
if (name && rest) {
|
||||
// Decode the name and value, and join value parts in case it contains '='
|
||||
const decodedName = decodeURIComponent(name.trim());
|
||||
const decodedValue = decodeURIComponent(rest.join('=').trim());
|
||||
cookies[decodedName] = decodedValue;
|
||||
}
|
||||
});
|
||||
|
||||
return cookies;
|
||||
}
|
@ -28,7 +28,7 @@ export default defineConfig((config) => {
|
||||
chrome129IssuePlugin(),
|
||||
config.mode === 'production' && optimizeCssModules({ apply: 'build' }),
|
||||
],
|
||||
envPrefix:["VITE_","OPENAI_LIKE_API_","OLLAMA_API_BASE_URL","LMSTUDIO_API_BASE_URL"],
|
||||
envPrefix: ["VITE_", "OPENAI_LIKE_API_", "OLLAMA_API_BASE_URL", "LMSTUDIO_API_BASE_URL","TOGETHER_API_BASE_URL"],
|
||||
css: {
|
||||
preprocessorOptions: {
|
||||
scss: {
|
||||
|
Loading…
Reference in New Issue
Block a user