updated to adapt baseurl setup

Anirban Kar 2024-12-11 14:02:21 +05:30
parent b4d0597120
commit 5d4b860c94
12 changed files with 122 additions and 71 deletions

View File

@@ -17,7 +17,6 @@ import Cookies from 'js-cookie';
 import * as Tooltip from '@radix-ui/react-tooltip';
 import styles from './BaseChat.module.scss';
-import type { ProviderInfo } from '~/utils/types';
 import { ExportChatButton } from '~/components/chat/chatExportAndImport/ExportChatButton';
 import { ImportButtons } from '~/components/chat/chatExportAndImport/ImportButtons';
 import { ExamplePrompts } from '~/components/chat/ExamplePrompts';
@@ -26,6 +25,7 @@ import GitCloneButton from './GitCloneButton';
 import FilePreview from './FilePreview';
 import { ModelSelector } from '~/components/chat/ModelSelector';
 import { SpeechRecognitionButton } from '~/components/chat/SpeechRecognition';
+import type { IProviderSetting, ProviderInfo } from '~/types/model';

 const TEXTAREA_MIN_HEIGHT = 76;
@@ -131,7 +131,26 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
       Cookies.remove('apiKeys');
     }

-    initializeModelList().then((modelList) => {
+    let providerSettings: Record<string, IProviderSetting> | undefined = undefined;
+
+    try {
+      const savedProviderSettings = Cookies.get('providers');
+
+      if (savedProviderSettings) {
+        const parsedProviderSettings = JSON.parse(savedProviderSettings);
+
+        if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) {
+          providerSettings = parsedProviderSettings;
+        }
+      }
+    } catch (error) {
+      console.error('Error loading Provider Settings from cookies:', error);
+
+      // Clear invalid cookie data
+      Cookies.remove('providers');
+    }
+
+    initializeModelList(providerSettings).then((modelList) => {
       setModelList(modelList);
     });
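For context, the effect above expects the `providers` cookie to hold a JSON-encoded map of provider names to `IProviderSetting` objects. A minimal sketch of the client-side write that would round-trip through this parser (the `Ollama` entry is an illustrative value, not part of the commit):

```ts
import Cookies from 'js-cookie';
import type { IProviderSetting } from '~/types/model';

// Hypothetical settings payload; any provider-name keys work.
const settings: Record<string, IProviderSetting> = {
  Ollama: { enabled: true, baseUrl: 'http://localhost:11434' },
};

// Stored as JSON so the try/JSON.parse block above can recover it;
// malformed values are caught and the cookie is cleared.
Cookies.set('providers', JSON.stringify(settings));
```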

View File

@@ -17,9 +17,9 @@ import { cubicEasingFn } from '~/utils/easings';
 import { createScopedLogger, renderLogger } from '~/utils/logger';
 import { BaseChat } from './BaseChat';
 import Cookies from 'js-cookie';
-import type { ProviderInfo } from '~/utils/types';
 import { debounce } from '~/utils/debounce';
 import { useSettings } from '~/lib/hooks/useSettings';
+import type { ProviderInfo } from '~/types/model';

 const toastAnimation = cssTransition({
   enter: 'animated fadeInRight',

View File

@@ -1,7 +1,8 @@
 import React, { useEffect, useState } from 'react';
 import { Switch } from '~/components/ui/Switch';
 import { useSettings } from '~/lib/hooks/useSettings';
-import { LOCAL_PROVIDERS, URL_CONFIGURABLE_PROVIDERS, type IProviderConfig } from '~/lib/stores/settings';
+import { LOCAL_PROVIDERS, URL_CONFIGURABLE_PROVIDERS } from '~/lib/stores/settings';
+import type { IProviderConfig } from '~/types/model';

 export default function ProvidersTab() {
   const { providers, updateProviderSettings, isLocalModel } = useSettings();

View File

@@ -11,6 +11,7 @@ import { createOpenRouter } from '@openrouter/ai-sdk-provider';
 import { createMistral } from '@ai-sdk/mistral';
 import { createCohere } from '@ai-sdk/cohere';
 import type { LanguageModelV1 } from 'ai';
+import type { IProviderSetting } from '~/types/model';

 export const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX ? parseInt(process.env.DEFAULT_NUM_CTX, 10) : 32768;
@@ -127,14 +128,20 @@ export function getXAIModel(apiKey: OptionalApiKey, model: string) {
   return openai(model);
 }

-export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) {
+export function getModel(
+  provider: string,
+  model: string,
+  env: Env,
+  apiKeys?: Record<string, string>,
+  providerSettings?: Record<string, IProviderSetting>,
+) {
   /*
    * let apiKey; // Declare first
    * let baseURL;
    */

   const apiKey = getAPIKey(env, provider, apiKeys); // Then assign
-  const baseURL = getBaseURL(env, provider);
+  const baseURL = providerSettings?.[provider].baseUrl || getBaseURL(env, provider);

   switch (provider) {
     case 'Anthropic':
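A hedged sketch of the resulting call contract (provider, model, and URL values are illustrative; `env` and `apiKeys` stand in for what the routes pass through). Note that, as written, the optional chain in `providerSettings?.[provider].baseUrl` guards only the map itself, so a supplied map is expected to include an entry for the active provider; without a map, resolution falls back to `getBaseURL(env, provider)`:

```ts
import { getModel } from '~/lib/.server/llm/model';
import type { IProviderSetting } from '~/types/model';

declare const env: Env; // supplied by context.cloudflare.env in the routes
declare const apiKeys: Record<string, string>;

// Hypothetical per-user override: this baseUrl wins over env configuration.
const providerSettings: Record<string, IProviderSetting> = {
  Ollama: { baseUrl: 'http://host.docker.internal:11434' },
};

const model = getModel('Ollama', 'llama2', env, apiKeys, providerSettings);
```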

View File

@@ -3,6 +3,7 @@ import { getModel } from '~/lib/.server/llm/model';
 import { MAX_TOKENS } from './constants';
 import { getSystemPrompt } from './prompts';
 import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
+import type { IProviderSetting } from '~/types/model';

 interface ToolResult<Name extends string, Args, Result> {
   toolCallId: string;
@@ -58,15 +59,17 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
   return { model, provider, content: cleanedContent };
 }

-export async function streamText(
-  messages: Messages,
-  env: Env,
-  options?: StreamingOptions,
-  apiKeys?: Record<string, string>,
-) {
+export async function streamText(props: {
+  messages: Messages;
+  env: Env;
+  options?: StreamingOptions;
+  apiKeys?: Record<string, string>;
+  providerSettings?: Record<string, IProviderSetting>;
+}) {
+  const { messages, env, options, apiKeys, providerSettings } = props;
   let currentModel = DEFAULT_MODEL;
   let currentProvider = DEFAULT_PROVIDER.name;
-  const MODEL_LIST = await getModelList(apiKeys || {});
+  const MODEL_LIST = await getModelList(apiKeys || {}, providerSettings);

   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
       const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -88,7 +91,7 @@ export async function streamText(
   const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;

   return _streamText({
-    model: getModel(currentProvider, currentModel, env, apiKeys) as any,
+    model: getModel(currentProvider, currentModel, env, apiKeys, providerSettings) as any,
     system: getSystemPrompt(),
     maxTokens: dynamicMaxTokens,
     messages: convertToCoreMessages(processedMessages as any),
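Converting the positional signature into a single props object keeps the growing list of optional arguments addressable by name. A sketch of a call site under the new shape (the message content is a placeholder):

```ts
import { streamText, type Messages } from '~/lib/.server/llm/stream-text';
import type { IProviderSetting } from '~/types/model';

declare const env: Env; // supplied by the Cloudflare context in the routes

const messages: Messages = [{ role: 'user', content: 'Hello' }];
const providerSettings: Record<string, IProviderSetting> = {};

// Before: streamText(messages, env, options, apiKeys)
// After: each field is named, and any optional one can simply be omitted.
const result = await streamText({ messages, env, apiKeys: {}, providerSettings });
```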

View File

@@ -1,14 +1,8 @@
 import { useStore } from '@nanostores/react';
-import {
-  isDebugMode,
-  isLocalModelsEnabled,
-  LOCAL_PROVIDERS,
-  providersStore,
-  type IProviderSetting,
-} from '~/lib/stores/settings';
+import { isDebugMode, isLocalModelsEnabled, LOCAL_PROVIDERS, providersStore } from '~/lib/stores/settings';
 import { useCallback, useEffect, useState } from 'react';
 import Cookies from 'js-cookie';
-import type { ProviderInfo } from '~/utils/types';
+import type { IProviderSetting, ProviderInfo } from '~/types/model';

 export function useSettings() {
   const providers = useStore(providersStore);

View File

@@ -1,7 +1,7 @@
 import { atom, map } from 'nanostores';
 import { workbenchStore } from './workbench';
-import type { ProviderInfo } from '~/utils/types';
 import { PROVIDER_LIST } from '~/utils/constants';
+import type { IProviderConfig } from '~/types/model';

 export interface Shortcut {
   key: string;
@@ -17,14 +17,6 @@ export interface Shortcuts {
   toggleTerminal: Shortcut;
 }

-export interface IProviderSetting {
-  enabled?: boolean;
-  baseUrl?: string;
-}
-
-export type IProviderConfig = ProviderInfo & {
-  settings: IProviderSetting;
-};
-
 export const URL_CONFIGURABLE_PROVIDERS = ['Ollama', 'LMStudio', 'OpenAILike'];
 export const LOCAL_PROVIDERS = ['OpenAILike', 'LMStudio', 'Ollama'];

View File

@@ -3,6 +3,7 @@ import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
 import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
 import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
 import SwitchableStream from '~/lib/.server/llm/switchable-stream';
+import type { IProviderSetting } from '~/types/model';

 export async function action(args: ActionFunctionArgs) {
   return chatAction(args);
@@ -38,6 +39,9 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
   // Parse the cookie's value (returns an object or null if no cookie exists)
   const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
+  const providerSettings: Record<string, IProviderSetting> = JSON.parse(
+    parseCookies(cookieHeader || '').providers || '{}',
+  );

   const stream = new SwitchableStream();
@@ -60,13 +64,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         messages.push({ role: 'assistant', content });
         messages.push({ role: 'user', content: CONTINUE_PROMPT });

-        const result = await streamText(messages, context.cloudflare.env, options, apiKeys);
+        const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings });

         return stream.switchSource(result.toAIStream());
       },
     };

-    const result = await streamText(messages, context.cloudflare.env, options, apiKeys);
+    const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings });

     stream.switchSource(result.toAIStream());
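The server half of the round trip: both values ride the same Cookie header, so one request carries API keys and provider overrides without any new endpoint. A hedged sketch of what the parse above yields (the header value is illustrative and shown decoded; `parseCookies` is the route-local helper reproduced in full in the api.enhancer.ts diff below):

```ts
import type { IProviderSetting } from '~/types/model';

// Route-local helper, shown in full in the api.enhancer.ts diff below.
declare function parseCookies(cookieHeader: string): Record<string, string>;

// Hypothetical incoming header after the client stored both cookies:
const cookieHeader = 'apiKeys={"Anthropic":"sk-..."}; providers={"Ollama":{"baseUrl":"http://localhost:11434"}}';

const parsed = parseCookies(cookieHeader);

// Both cookies arrive as JSON strings and default to empty objects when absent.
const apiKeys = JSON.parse(parsed.apiKeys || '{}');
const providerSettings: Record<string, IProviderSetting> = JSON.parse(parsed.providers || '{}');
```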

View File

@@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
 import { StreamingTextResponse, parseStreamPart } from 'ai';
 import { streamText } from '~/lib/.server/llm/stream-text';
 import { stripIndents } from '~/utils/stripIndent';
-import type { ProviderInfo } from '~/types/model';
+import type { IProviderSetting, ProviderInfo } from '~/types/model';

 const encoder = new TextEncoder();
 const decoder = new TextDecoder();
@@ -11,8 +11,28 @@ export async function action(args: ActionFunctionArgs) {
   return enhancerAction(args);
 }

+function parseCookies(cookieHeader: string) {
+  const cookies: any = {};
+
+  // Split the cookie string by semicolons and spaces
+  const items = cookieHeader.split(';').map((cookie) => cookie.trim());
+
+  items.forEach((item) => {
+    const [name, ...rest] = item.split('=');
+
+    if (name && rest) {
+      // Decode the name and value, and join value parts in case it contains '='
+      const decodedName = decodeURIComponent(name.trim());
+      const decodedValue = decodeURIComponent(rest.join('=').trim());
+      cookies[decodedName] = decodedValue;
+    }
+  });
+
+  return cookies;
+}
+
 async function enhancerAction({ context, request }: ActionFunctionArgs) {
-  const { message, model, provider, apiKeys } = await request.json<{
+  const { message, model, provider } = await request.json<{
     message: string;
     model: string;
     provider: ProviderInfo;
@@ -36,9 +56,17 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
     });
   }

+  const cookieHeader = request.headers.get('Cookie');
+
+  // Parse the cookie's value (returns an object or null if no cookie exists)
+  const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
+  const providerSettings: Record<string, IProviderSetting> = JSON.parse(
+    parseCookies(cookieHeader || '').providers || '{}',
+  );
+
   try {
-    const result = await streamText(
-      [
+    const result = await streamText({
+      messages: [
         {
           role: 'user',
           content:
@@ -73,10 +101,10 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
       `,
         },
       ],
-      context.cloudflare.env,
-      undefined,
+      env: context.cloudflare.env,
       apiKeys,
-    );
+      providerSettings,
+    });

     const transformStream = new TransformStream({
       transform(chunk, controller) {

View File

@@ -3,9 +3,17 @@ import type { ModelInfo } from '~/utils/types';
 export type ProviderInfo = {
   staticModels: ModelInfo[];
   name: string;
-  getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>;
+  getDynamicModels?: (apiKeys?: Record<string, string>, providerSettings?: IProviderSetting) => Promise<ModelInfo[]>;
   getApiKeyLink?: string;
   labelForGetApiKey?: string;
   icon?: string;
-  isEnabled?: boolean;
 };
+
+export interface IProviderSetting {
+  enabled?: boolean;
+  baseUrl?: string;
+}
+
+export type IProviderConfig = ProviderInfo & {
+  settings: IProviderSetting;
+};
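With the types consolidated here, a provider entry and its per-user settings compose into one value. A sketch of how something might satisfy `IProviderConfig` (the field values are placeholders, not from the commit):

```ts
import type { IProviderConfig } from '~/types/model';

// Hypothetical provider entry combining static ProviderInfo fields
// with the per-user settings split out by this commit.
const ollamaConfig: IProviderConfig = {
  name: 'Ollama',
  staticModels: [],
  settings: { enabled: true, baseUrl: 'http://localhost:11434' },
};
```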

View File

@@ -1,6 +1,6 @@
 import Cookies from 'js-cookie';
 import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
-import type { ProviderInfo } from '~/types/model';
+import type { ProviderInfo, IProviderSetting } from '~/types/model';

 export const WORK_DIR_NAME = 'project';
 export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
@@ -295,13 +295,16 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(
 export let MODEL_LIST: ModelInfo[] = [...staticModels];

-export async function getModelList(apiKeys: Record<string, string>) {
+export async function getModelList(
+  apiKeys: Record<string, string>,
+  providerSettings?: Record<string, IProviderSetting>,
+) {
   MODEL_LIST = [
     ...(
       await Promise.all(
         PROVIDER_LIST.filter(
           (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
-        ).map((p) => p.getDynamicModels(apiKeys)),
+        ).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name])),
       )
     ).flat(),
     ...staticModels,
@@ -309,9 +312,9 @@ export async function getModelList(apiKeys: Record<string, string>) {
   return MODEL_LIST;
 }

-async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
+async function getTogetherModels(apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
   try {
-    const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
+    const baseUrl = settings?.baseUrl || import.meta.env.TOGETHER_API_BASE_URL || '';
     const provider = 'Together';

     if (!baseUrl) {
@@ -350,8 +353,8 @@ async function getTogetherModels(apiKeys?: Record<string, string>): Promise<Mode
   }
 }

-const getOllamaBaseUrl = () => {
-  const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
+const getOllamaBaseUrl = (settings?: IProviderSetting) => {
+  const defaultBaseUrl = settings?.baseUrl || import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';

   // Check if we're in the browser
   if (typeof window !== 'undefined') {
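The pattern repeated across every fetcher in this file is a three-step fallback: user setting, then env var, then hard-coded default. A sketch of the resolution order for Ollama, restated from the diff (`settings` stands in for the per-provider slice passed down from `getModelList`):

```ts
import type { IProviderSetting } from '~/types/model';

declare const settings: IProviderSetting | undefined;

// Fallback chain used by each fetcher after this commit:
//   1. settings?.baseUrl                     (per-user value from the providers cookie)
//   2. import.meta.env.OLLAMA_API_BASE_URL   (deployment-level env var)
//   3. 'http://localhost:11434'              (hard-coded default)
// getOllamaBaseUrl() additionally rewrites 'localhost' to 'host.docker.internal' in Docker.
const baseUrl = settings?.baseUrl || import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
```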
@@ -365,7 +368,7 @@ const getOllamaBaseUrl = () => {
   return isDocker ? defaultBaseUrl.replace('localhost', 'host.docker.internal') : defaultBaseUrl;
 };

-async function getOllamaModels(): Promise<ModelInfo[]> {
+async function getOllamaModels(apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
   /*
    * if (typeof window === 'undefined') {
    *   return [];
@@ -373,7 +376,7 @@ async function getOllamaModels(): Promise<ModelInfo[]> {
    */
   try {
-    const baseUrl = getOllamaBaseUrl();
+    const baseUrl = getOllamaBaseUrl(settings);

     const response = await fetch(`${baseUrl}/api/tags`);
     const data = (await response.json()) as OllamaApiResponse;
@@ -389,20 +392,21 @@ async function getOllamaModels(): Promise<ModelInfo[]> {
   }
 }

-async function getOpenAILikeModels(): Promise<ModelInfo[]> {
+async function getOpenAILikeModels(
+  apiKeys?: Record<string, string>,
+  settings?: IProviderSetting,
+): Promise<ModelInfo[]> {
   try {
-    const baseUrl = import.meta.env.OPENAI_LIKE_API_BASE_URL || '';
+    const baseUrl = settings?.baseUrl || import.meta.env.OPENAI_LIKE_API_BASE_URL || '';

     if (!baseUrl) {
       return [];
     }

-    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
-
-    const apikeys = JSON.parse(Cookies.get('apiKeys') || '{}');
-
-    if (apikeys && apikeys.OpenAILike) {
-      apiKey = apikeys.OpenAILike;
+    let apiKey = '';
+
+    if (apiKeys && apiKeys.OpenAILike) {
+      apiKey = apiKeys.OpenAILike;
     }

     const response = await fetch(`${baseUrl}/models`, {
@@ -456,13 +460,13 @@ async function getOpenRouterModels(): Promise<ModelInfo[]> {
   }));
 }

-async function getLMStudioModels(): Promise<ModelInfo[]> {
+async function getLMStudioModels(_apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
   if (typeof window === 'undefined') {
     return [];
   }

   try {
-    const baseUrl = import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
+    const baseUrl = settings?.baseUrl || import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
     const response = await fetch(`${baseUrl}/v1/models`);
     const data = (await response.json()) as any;
@@ -477,7 +481,7 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
   }
 }

-async function initializeModelList(): Promise<ModelInfo[]> {
+async function initializeModelList(providerSettings?: Record<string, IProviderSetting>): Promise<ModelInfo[]> {
   let apiKeys: Record<string, string> = {};

   try {
await Promise.all( await Promise.all(
PROVIDER_LIST.filter( PROVIDER_LIST.filter(
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels, (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
).map((p) => p.getDynamicModels(apiKeys)), ).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name])),
) )
).flat(), ).flat(),
...staticModels, ...staticModels,
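With every fetcher threaded through, BaseChat's effect (first diff above) becomes the single entry point: it reads the cookie once and `getModelList` fans the per-provider slices out by name via `providerSettings?.[p.name]`. A hedged usage sketch, assuming `initializeModelList` is exported from this module as BaseChat's call implies (the LMStudio entry is an illustrative value):

```ts
import { initializeModelList } from '~/utils/constants';
import type { IProviderSetting } from '~/types/model';

// Hypothetical settings map, keyed by provider name as in PROVIDER_LIST.
const providerSettings: Record<string, IProviderSetting> = {
  LMStudio: { baseUrl: 'http://localhost:1234' },
};

// Each provider's fetcher receives only its own slice of the map.
initializeModelList(providerSettings).then((models) => {
  console.log(`loaded ${models.length} models`);
});
```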

View File

@@ -26,12 +26,3 @@ export interface ModelInfo {
   provider: string;
   maxTokenAllowed: number;
 }

-export interface ProviderInfo {
-  staticModels: ModelInfo[];
-  name: string;
-  getDynamicModels?: () => Promise<ModelInfo[]>;
-  getApiKeyLink?: string;
-  labelForGetApiKey?: string;
-  icon?: string;
-}