Anirban Kar 2024-12-06 16:58:04 +05:30
parent 5ead47992d
commit 7efad13284
3 changed files with 60 additions and 50 deletions

View File

@@ -58,10 +58,15 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
   return { model, provider, content: cleanedContent };
 }
-export async function streamText(messages: Messages, env: Env, options?: StreamingOptions,apiKeys?: Record<string, string>) {
+export async function streamText(
+  messages: Messages,
+  env: Env,
+  options?: StreamingOptions,
+  apiKeys?: Record<string, string>,
+) {
   let currentModel = DEFAULT_MODEL;
   let currentProvider = DEFAULT_PROVIDER.name;
-  const MODEL_LIST = await getModelList(apiKeys||{});
+  const MODEL_LIST = await getModelList(apiKeys || {});
   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
       const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -69,6 +74,7 @@ export async function streamText(messages: Messages, env: Env, options?: Streami
       if (MODEL_LIST.find((m) => m.name === model)) {
         currentModel = model;
       }
       currentProvider = provider;
       return { ...message, content };
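The hunk above only reflows the signature and normalizes spacing; the selection logic is unchanged: a model tag extracted from a user message is honoured only if it appears in MODEL_LIST, while the provider tag is always accepted. A self-contained restatement of that logic, where pickModel and the sample model names are hypothetical stand-ins rather than code from this repository:

// Illustration only: pickModel and these model names are hypothetical; the
// branching mirrors the streamText hunk above.
interface ModelInfo {
  name: string;
}

const MODEL_LIST: ModelInfo[] = [{ name: 'gpt-4o' }, { name: 'claude-3-5-sonnet' }];
const DEFAULT_MODEL = 'gpt-4o';
const DEFAULT_PROVIDER = 'OpenAI';

function pickModel(requestedModel: string, requestedProvider: string) {
  let currentModel = DEFAULT_MODEL;
  let currentProvider = DEFAULT_PROVIDER;

  // Only switch models when the requested one is actually known.
  if (MODEL_LIST.find((m) => m.name === requestedModel)) {
    currentModel = requestedModel;
  }

  // The provider from the message is taken as-is.
  currentProvider = requestedProvider;

  return { currentModel, currentProvider };
}

console.log(pickModel('claude-3-5-sonnet', 'Anthropic')); // honoured
console.log(pickModel('unknown-model', 'Anthropic')); // falls back to gpt-4o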

View File

@@ -8,8 +8,8 @@ export async function action(args: ActionFunctionArgs) {
   return chatAction(args);
 }
-function parseCookies(cookieHeader:string) {
-  const cookies:any = {};
+function parseCookies(cookieHeader: string) {
+  const cookies: any = {};
   // Split the cookie string by semicolons and spaces
   const items = cookieHeader.split(';').map((cookie) => cookie.trim());
@@ -29,7 +29,7 @@ function parseCookies(cookieHeader:string) {
 }
 async function chatAction({ context, request }: ActionFunctionArgs) {
-  const { messages, model } = await request.json<{
+  const { messages } = await request.json<{
     messages: Messages;
     model: string;
   }>();
@@ -37,7 +37,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
   const cookieHeader = request.headers.get('Cookie');
   // Parse the cookie's value (returns an object or null if no cookie exists)
-  const apiKeys = JSON.parse(parseCookies(cookieHeader||"").apiKeys || '{}');
+  const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
   const stream = new SwitchableStream();
@@ -60,7 +60,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         messages.push({ role: 'assistant', content });
         messages.push({ role: 'user', content: CONTINUE_PROMPT });
-        const result = await streamText(messages, context.cloudflare.env, options,apiKeys);
+        const result = await streamText(messages, context.cloudflare.env, options, apiKeys);
         return stream.switchSource(result.toAIStream());
       },
@@ -76,7 +76,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         contentType: 'text/plain; charset=utf-8',
       },
     });
-  } catch (error:any) {
+  } catch (error: any) {
     console.log(error);
     if (error.message?.includes('API key')) {
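Only a few lines of parseCookies appear in this diff; its job is to turn the raw Cookie header into a map so the apiKeys entry can be JSON-parsed, as chatAction does above. A minimal sketch under that assumption (the key/value handling below is illustrative, not the exact implementation from this file):

// Sketch only: the split-on-';' step matches the diff, the rest is assumed.
function parseCookies(cookieHeader: string): Record<string, string> {
  const cookies: Record<string, string> = {};

  // Split the cookie string by semicolons and spaces
  const items = cookieHeader.split(';').map((cookie) => cookie.trim());

  for (const item of items) {
    const [name, ...rest] = item.split('=');

    if (name && rest.length > 0) {
      cookies[decodeURIComponent(name.trim())] = decodeURIComponent(rest.join('=').trim());
    }
  }

  return cookies;
}

// Usage mirroring chatAction: pull the per-provider keys out of the Cookie header.
const header = `apiKeys=${encodeURIComponent(JSON.stringify({ Together: 'tg-example' }))}`;
const apiKeys = JSON.parse(parseCookies(header).apiKeys || '{}');
console.log(apiKeys.Together); // "tg-example"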

View File

@@ -295,7 +295,6 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(
 export let MODEL_LIST: ModelInfo[] = [...staticModels];
 export async function getModelList(apiKeys: Record<string, string>) {
   MODEL_LIST = [
     ...(
@@ -312,13 +311,14 @@ export async function getModelList(apiKeys: Record<string, string>) {
 async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
   try {
-    let baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
-    let provider='Together'
+    const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
+    const provider = 'Together';
     if (!baseUrl) {
       return [];
     }
-    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? ''
+    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
     if (apiKeys && apiKeys[provider]) {
       apiKey = apiKeys[provider];
@@ -334,21 +334,21 @@ async function getTogetherModels(apiKeys?: Record<string, string>): Promise<Mode
       },
     });
     const res = (await response.json()) as any;
-    let data: any[] = (res || []).filter((model: any) => model.type=='chat')
+    const data: any[] = (res || []).filter((model: any) => model.type == 'chat');
     return data.map((m: any) => ({
       name: m.id,
-      label: `${m.display_name} - in:$${(m.pricing.input).toFixed(
+      label: `${m.display_name} - in:$${m.pricing.input.toFixed(
         2,
-      )} out:$${(m.pricing.output).toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
+      )} out:$${m.pricing.output.toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
-      provider: provider,
+      provider,
       maxTokenAllowed: 8000,
     }));
   } catch (e) {
     console.error('Error getting OpenAILike models:', e);
     return [];
   }
 }
 const getOllamaBaseUrl = () => {
   const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
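For reference, the label template in the getTogetherModels hunk above produces a string like the one below; the model entry here is made-up sample data in the shape the code expects from the Together models response (id, display_name, pricing, context_length, type):

// Sample data only — not a real Together price list.
const m = {
  id: 'meta-llama/Llama-3-70b-chat-hf',
  display_name: 'Llama 3 70B',
  pricing: { input: 0.88, output: 0.88 },
  context_length: 8192,
  type: 'chat',
};

const label = `${m.display_name} - in:$${m.pricing.input.toFixed(
  2,
)} out:$${m.pricing.output.toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`;

console.log(label); // "Llama 3 70B - in:$0.88 out:$0.88 - context 8k"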
@@ -396,11 +396,13 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
   if (!baseUrl) {
     return [];
   }
   let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
-  let apikeys = JSON.parse(Cookies.get('apiKeys')||'{}')
-  if (apikeys && apikeys['OpenAILike']){
-    apiKey = apikeys['OpenAILike'];
+  const apikeys = JSON.parse(Cookies.get('apiKeys') || '{}');
+
+  if (apikeys && apikeys.OpenAILike) {
+    apiKey = apikeys.OpenAILike;
   }
   const response = await fetch(`${baseUrl}/models`, {
@@ -458,6 +460,7 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
   if (typeof window === 'undefined') {
     return [];
   }
   try {
     const baseUrl = import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
     const response = await fetch(`${baseUrl}/v1/models`);
@@ -476,6 +479,7 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
 async function initializeModelList(): Promise<ModelInfo[]> {
   let apiKeys: Record<string, string> = {};
   try {
     const storedApiKeys = Cookies.get('apiKeys');
@@ -486,9 +490,8 @@ async function initializeModelList(): Promise<ModelInfo[]> {
         apiKeys = parsedKeys;
       }
     }
-  } catch (error) {
+  } catch (error: any) {
+    console.warn(`Failed to fetch apikeys from cookies:${error?.message}`);
   }
   MODEL_LIST = [
     ...(
@@ -500,6 +503,7 @@ async function initializeModelList(): Promise<ModelInfo[]> {
     ).flat(),
     ...staticModels,
   ];
   return MODEL_LIST;
 }
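Taken together, the three files share one convention: per-provider keys live in a single JSON-encoded apiKeys cookie, read client-side via Cookies.get (js-cookie, which this file already imports as Cookies) and re-parsed server-side from the request's Cookie header in the chat action. A hedged sketch of the client side, with example key values only:

import Cookies from 'js-cookie';

// One JSON blob holds every provider's key; the property names must match the
// provider names checked above (e.g. 'Together', 'OpenAILike').
const apiKeys: Record<string, string> = {
  Together: 'tg-example-key', // example values, not real keys
  OpenAILike: 'sk-example-key',
};

Cookies.set('apiKeys', JSON.stringify(apiKeys));

// Reading it back mirrors initializeModelList / getOpenAILikeModels above.
const parsed = JSON.parse(Cookies.get('apiKeys') || '{}');
console.log(parsed.Together); // "tg-example-key"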