Merge pull request #513 from thecodacus/together-ai-dynamic-model-list

feat(Dynamic Models): Added Together AI Dynamic Models
Anirban Kar, 2024-12-06 16:59:57 +05:30 (committed by GitHub)
commit 2e49905ed9
5 changed files with 103 additions and 26 deletions

app/lib/.server/llm/stream-text.ts

@@ -1,11 +1,8 @@
-// eslint-disable-next-line @typescript-eslint/ban-ts-comment
-// @ts-nocheck TODO: Provider proper types
-
 import { convertToCoreMessages, streamText as _streamText } from 'ai';
 import { getModel } from '~/lib/.server/llm/model';
 import { MAX_TOKENS } from './constants';
 import { getSystemPrompt } from './prompts';
-import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
+import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
 
 interface ToolResult<Name extends string, Args, Result> {
   toolCallId: string;
@@ -43,7 +40,7 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
   * Extract provider
   * const providerMatch = message.content.match(PROVIDER_REGEX);
   */
-  const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
+  const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER.name;
 
   const cleanedContent = Array.isArray(message.content)
     ? message.content.map((item) => {
@@ -61,10 +58,15 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
   return { model, provider, content: cleanedContent };
 }
 
-export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
+export async function streamText(
+  messages: Messages,
+  env: Env,
+  options?: StreamingOptions,
+  apiKeys?: Record<string, string>,
+) {
   let currentModel = DEFAULT_MODEL;
-  let currentProvider = DEFAULT_PROVIDER;
+  let currentProvider = DEFAULT_PROVIDER.name;
+  const MODEL_LIST = await getModelList(apiKeys || {});
 
   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
       const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -86,10 +88,10 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOpti
   const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
 
   return _streamText({
-    ...options,
-    model: getModel(currentProvider, currentModel, env, apiKeys),
+    model: getModel(currentProvider, currentModel, env, apiKeys) as any,
     system: getSystemPrompt(),
     maxTokens: dynamicMaxTokens,
-    messages: convertToCoreMessages(processedMessages),
+    messages: convertToCoreMessages(processedMessages as any),
+    ...options,
   });
 }
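Since `streamText` now awaits `getModelList`, every call site must await it before consuming the stream. A minimal, hypothetical call-site sketch, assuming the `Messages`, `Env`, and `StreamingOptions` types from this module (the wrapper function name is illustrative):

```ts
// Hypothetical call site: the result must now be awaited before use.
async function runChat(messages: Messages, env: Env, apiKeys: Record<string, string>) {
  const options: StreamingOptions = { toolChoice: 'none' };

  // Before this change: const result = streamText(messages, env, options);
  const result = await streamText(messages, env, options, apiKeys);

  return result.toAIStream(); // same consumption pattern as the route handler below
}
```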

app/routes/api.chat.ts

@@ -1,6 +1,3 @@
-// eslint-disable-next-line @typescript-eslint/ban-ts-comment
-// @ts-nocheck TODO: Provider proper types
-
 import { type ActionFunctionArgs } from '@remix-run/cloudflare';
 import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
 import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
@@ -11,8 +8,8 @@ export async function action(args: ActionFunctionArgs) {
   return chatAction(args);
 }
 
-function parseCookies(cookieHeader) {
-  const cookies = {};
+function parseCookies(cookieHeader: string) {
+  const cookies: any = {};
 
   // Split the cookie string by semicolons and spaces
   const items = cookieHeader.split(';').map((cookie) => cookie.trim());
@@ -32,7 +29,7 @@ function parseCookies(cookieHeader) {
 }
 
 async function chatAction({ context, request }: ActionFunctionArgs) {
-  const { messages, model } = await request.json<{
+  const { messages } = await request.json<{
     messages: Messages;
     model: string;
   }>();
@@ -40,15 +37,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
   const cookieHeader = request.headers.get('Cookie');
 
   // Parse the cookie's value (returns an object or null if no cookie exists)
-  const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || '{}');
+  const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
 
   const stream = new SwitchableStream();
 
   try {
     const options: StreamingOptions = {
       toolChoice: 'none',
-      apiKeys,
-      model,
       onFinish: async ({ text: content, finishReason }) => {
         if (finishReason !== 'length') {
           return stream.close();
@@ -65,7 +60,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         messages.push({ role: 'assistant', content });
         messages.push({ role: 'user', content: CONTINUE_PROMPT });
 
-        const result = await streamText(messages, context.cloudflare.env, options);
+        const result = await streamText(messages, context.cloudflare.env, options, apiKeys);
 
         return stream.switchSource(result.toAIStream());
       },
@@ -81,7 +76,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         contentType: 'text/plain; charset=utf-8',
       },
     });
-  } catch (error) {
+  } catch (error: any) {
     console.log(error);
 
     if (error.message?.includes('API key')) {
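The route no longer accepts `model` and `apiKeys` through `StreamingOptions`; keys arrive via an `apiKeys` cookie and are forwarded to `streamText` as a separate argument. A sketch of the round trip, assuming the cookie holds one JSON object keyed by provider name (names and values are illustrative; the server expression mirrors the diff above):

```ts
import Cookies from 'js-cookie';

// Client (hypothetical): persist per-provider keys as a single JSON cookie.
Cookies.set('apiKeys', JSON.stringify({ Together: 'together-key' }));

// Server: chatAction recovers the same object from the Cookie header.
function readApiKeys(request: Request): Record<string, string> {
  const cookieHeader = request.headers.get('Cookie');
  return JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
}
```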

app/types/model.ts

@@ -3,7 +3,7 @@ import type { ModelInfo } from '~/utils/types';
 export type ProviderInfo = {
   staticModels: ModelInfo[];
   name: string;
-  getDynamicModels?: () => Promise<ModelInfo[]>;
+  getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>;
   getApiKeyLink?: string;
   labelForGetApiKey?: string;
   icon?: string;
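With the widened signature, a provider can prefer a user-supplied key when listing its models. A minimal conforming provider sketch, assuming `ModelInfo` carries the `name`/`label`/`provider`/`maxTokenAllowed` fields used elsewhere in this PR (the `Example` provider and its model are hypothetical):

```ts
const exampleProvider: ProviderInfo = {
  name: 'Example',
  staticModels: [],
  getDynamicModels: async (apiKeys?: Record<string, string>) => {
    // Prefer the per-user key; without one, expose no dynamic models.
    const apiKey = apiKeys?.Example ?? '';

    if (!apiKey) {
      return [];
    }

    // A real implementation would fetch `${baseUrl}/models` here, as
    // getTogetherModels does below.
    return [{ name: 'example-model', label: 'Example Model', provider: 'Example', maxTokenAllowed: 8000 }];
  },
};
```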

app/utils/constants.ts

@@ -1,3 +1,4 @@
+import Cookies from 'js-cookie';
 import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
 import type { ProviderInfo } from '~/types/model';
 
@@ -262,6 +263,7 @@ const PROVIDER_LIST: ProviderInfo[] = [
   },
   {
     name: 'Together',
+    getDynamicModels: getTogetherModels,
     staticModels: [
       {
         name: 'Qwen/Qwen2.5-Coder-32B-Instruct',
@@ -293,6 +295,61 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(
 
 export let MODEL_LIST: ModelInfo[] = [...staticModels];
 
+export async function getModelList(apiKeys: Record<string, string>) {
+  MODEL_LIST = [
+    ...(
+      await Promise.all(
+        PROVIDER_LIST.filter(
+          (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
+        ).map((p) => p.getDynamicModels(apiKeys)),
+      )
+    ).flat(),
+    ...staticModels,
+  ];
+  return MODEL_LIST;
+}
+
+async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
+  try {
+    const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
+    const provider = 'Together';
+
+    if (!baseUrl) {
+      return [];
+    }
+
+    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+
+    if (apiKeys && apiKeys[provider]) {
+      apiKey = apiKeys[provider];
+    }
+
+    if (!apiKey) {
+      return [];
+    }
+
+    const response = await fetch(`${baseUrl}/models`, {
+      headers: {
+        Authorization: `Bearer ${apiKey}`,
+      },
+    });
+    const res = (await response.json()) as any;
+    const data: any[] = (res || []).filter((model: any) => model.type === 'chat');
+
+    return data.map((m: any) => ({
+      name: m.id,
+      label: `${m.display_name} - in:$${m.pricing.input.toFixed(
+        2,
+      )} out:$${m.pricing.output.toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
+      provider,
+      maxTokenAllowed: 8000,
+    }));
+  } catch (e) {
+    console.error('Error getting Together models:', e);
+    return [];
+  }
+}
+
 const getOllamaBaseUrl = () => {
   const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
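For reference, this is roughly what the mapping in `getTogetherModels` produces for a single `/models` entry. The sample fields are made up, but they follow the shape the code reads (`id`, `display_name`, `type`, `pricing`, `context_length`):

```ts
const sample = {
  id: 'Qwen/Qwen2.5-Coder-32B-Instruct',
  display_name: 'Qwen2.5 Coder 32B',
  type: 'chat',
  pricing: { input: 0.8, output: 0.8 },
  context_length: 32768,
};

// Mirrors the label template above:
const label =
  `${sample.display_name} - in:$${sample.pricing.input.toFixed(2)} ` +
  `out:$${sample.pricing.output.toFixed(2)} - context ${Math.floor(sample.context_length / 1000)}k`;
// -> 'Qwen2.5 Coder 32B - in:$0.80 out:$0.80 - context 32k'
```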
@@ -340,7 +397,14 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
       return [];
     }
 
-    const apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+    const apikeys = JSON.parse(Cookies.get('apiKeys') || '{}');
+
+    if (apikeys && apikeys.OpenAILike) {
+      apiKey = apikeys.OpenAILike;
+    }
+
     const response = await fetch(`${baseUrl}/models`, {
       headers: {
         Authorization: `Bearer ${apiKey}`,
@@ -414,16 +478,32 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
 }
 
 async function initializeModelList(): Promise<ModelInfo[]> {
+  let apiKeys: Record<string, string> = {};
+
+  try {
+    const storedApiKeys = Cookies.get('apiKeys');
+
+    if (storedApiKeys) {
+      const parsedKeys = JSON.parse(storedApiKeys);
+
+      if (typeof parsedKeys === 'object' && parsedKeys !== null) {
+        apiKeys = parsedKeys;
+      }
+    }
+  } catch (error: any) {
+    console.warn(`Failed to fetch API keys from cookies: ${error?.message}`);
+  }
+
   MODEL_LIST = [
     ...(
       await Promise.all(
         PROVIDER_LIST.filter(
           (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
-        ).map((p) => p.getDynamicModels()),
+        ).map((p) => p.getDynamicModels(apiKeys)),
       )
     ).flat(),
     ...staticModels,
   ];
 
   return MODEL_LIST;
 }
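`initializeModelList` runs in the browser, so it reads the same `apiKeys` cookie via `js-cookie` and guards the parse. A small sketch isolating that guard with illustrative inputs, showing why the `typeof` check matters:

```ts
function safeParseApiKeys(raw: string | undefined): Record<string, string> {
  let apiKeys: Record<string, string> = {};

  try {
    if (raw) {
      const parsed = JSON.parse(raw);

      // JSON.parse('"oops"') or JSON.parse('42') succeed but are not key maps,
      // so keep the empty default unless we parsed a non-null object.
      if (typeof parsed === 'object' && parsed !== null) {
        apiKeys = parsed;
      }
    }
  } catch {
    // Malformed cookie: fall back to no keys rather than throwing.
  }

  return apiKeys;
}

safeParseApiKeys('{"Together":"t-key"}'); // { Together: 't-key' }
safeParseApiKeys('not json'); // {}
```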

vite.config.ts

@@ -27,7 +27,7 @@ export default defineConfig((config) => {
       chrome129IssuePlugin(),
       config.mode === 'production' && optimizeCssModules({ apply: 'build' }),
     ],
-    envPrefix:["VITE_","OPENAI_LIKE_API_","OLLAMA_API_BASE_URL","LMSTUDIO_API_BASE_URL"],
+    envPrefix: ["VITE_", "OPENAI_LIKE_API_", "OLLAMA_API_BASE_URL", "LMSTUDIO_API_BASE_URL", "TOGETHER_API_BASE_URL"],
     css: {
       preprocessorOptions: {
         scss: {
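`TOGETHER_API_BASE_URL` joins `envPrefix` because Vite only exposes matching variables on `import.meta.env`; without this entry, `getTogetherModels` would read `undefined` and return early. A quick illustration (the URL is an assumption based on Together's public API, not part of this PR):

```ts
// .env.local (hypothetical): TOGETHER_API_BASE_URL=https://api.together.xyz/v1

// With the prefix registered, Vite statically injects the value here:
const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';

// Without the envPrefix entry, this would be undefined at build time and
// getTogetherModels would hit its `if (!baseUrl) return [];` guard.
```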