diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index 57a153d..b17baf9 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -17,7 +17,6 @@ import Cookies from 'js-cookie'; import * as Tooltip from '@radix-ui/react-tooltip'; import styles from './BaseChat.module.scss'; -import type { ProviderInfo } from '~/utils/types'; import { ExportChatButton } from '~/components/chat/chatExportAndImport/ExportChatButton'; import { ImportButtons } from '~/components/chat/chatExportAndImport/ImportButtons'; import { ExamplePrompts } from '~/components/chat/ExamplePrompts'; @@ -26,6 +25,7 @@ import GitCloneButton from './GitCloneButton'; import FilePreview from './FilePreview'; import { ModelSelector } from '~/components/chat/ModelSelector'; import { SpeechRecognitionButton } from '~/components/chat/SpeechRecognition'; +import type { IProviderSetting, ProviderInfo } from '~/types/model'; const TEXTAREA_MIN_HEIGHT = 76; @@ -131,7 +131,26 @@ export const BaseChat = React.forwardRef( Cookies.remove('apiKeys'); } - initializeModelList().then((modelList) => { + let providerSettings: Record<string, IProviderSetting> | undefined = undefined; + + try { + const savedProviderSettings = Cookies.get('providers'); + + if (savedProviderSettings) { + const parsedProviderSettings = JSON.parse(savedProviderSettings); + + if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) { + providerSettings = parsedProviderSettings; + } + } + } catch (error) { + console.error('Error loading Provider Settings from cookies:', error); + + // Clear invalid cookie data + Cookies.remove('providers'); + } + + initializeModelList(providerSettings).then((modelList) => { setModelList(modelList); }); diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 7c67a75..cd651cb 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -17,9 +17,9 @@ import { cubicEasingFn } from
'~/utils/easings'; import { createScopedLogger, renderLogger } from '~/utils/logger'; import { BaseChat } from './BaseChat'; import Cookies from 'js-cookie'; -import type { ProviderInfo } from '~/utils/types'; import { debounce } from '~/utils/debounce'; import { useSettings } from '~/lib/hooks/useSettings'; +import type { ProviderInfo } from '~/types/model'; const toastAnimation = cssTransition({ enter: 'animated fadeInRight', diff --git a/app/components/settings/providers/ProvidersTab.tsx b/app/components/settings/providers/ProvidersTab.tsx index 0b87959..309afb8 100644 --- a/app/components/settings/providers/ProvidersTab.tsx +++ b/app/components/settings/providers/ProvidersTab.tsx @@ -1,7 +1,8 @@ import React, { useEffect, useState } from 'react'; import { Switch } from '~/components/ui/Switch'; import { useSettings } from '~/lib/hooks/useSettings'; -import { LOCAL_PROVIDERS, URL_CONFIGURABLE_PROVIDERS, type IProviderConfig } from '~/lib/stores/settings'; +import { LOCAL_PROVIDERS, URL_CONFIGURABLE_PROVIDERS } from '~/lib/stores/settings'; +import type { IProviderConfig } from '~/types/model'; export default function ProvidersTab() { const { providers, updateProviderSettings, isLocalModel } = useSettings(); diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts index ecbcd64..2588c2b 100644 --- a/app/lib/.server/llm/model.ts +++ b/app/lib/.server/llm/model.ts @@ -11,6 +11,7 @@ import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { createMistral } from '@ai-sdk/mistral'; import { createCohere } from '@ai-sdk/cohere'; import type { LanguageModelV1 } from 'ai'; +import type { IProviderSetting } from '~/types/model'; export const DEFAULT_NUM_CTX = process.env.DEFAULT_NUM_CTX ? 
parseInt(process.env.DEFAULT_NUM_CTX, 10) : 32768; @@ -127,14 +128,20 @@ export function getXAIModel(apiKey: OptionalApiKey, model: string) { return openai(model); } -export function getModel(provider: string, model: string, env: Env, apiKeys?: Record<string, string>) { +export function getModel( + provider: string, + model: string, + env: Env, + apiKeys?: Record<string, string>, + providerSettings?: Record<string, IProviderSetting>, +) { /* * let apiKey; // Declare first * let baseURL; */ const apiKey = getAPIKey(env, provider, apiKeys); // Then assign - const baseURL = getBaseURL(env, provider); + const baseURL = providerSettings?.[provider]?.baseUrl || getBaseURL(env, provider); switch (provider) { case 'Anthropic': diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts index f408ba2..52271f0 100644 --- a/app/lib/.server/llm/stream-text.ts +++ b/app/lib/.server/llm/stream-text.ts @@ -3,6 +3,7 @@ import { getModel } from '~/lib/.server/llm/model'; import { MAX_TOKENS } from './constants'; import { getSystemPrompt } from './prompts'; import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants'; +import type { IProviderSetting } from '~/types/model'; interface ToolResult { toolCallId: string; @@ -58,15 +59,17 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid return { model, provider, content: cleanedContent }; } -export async function streamText( - messages: Messages, - env: Env, - options?: StreamingOptions, - apiKeys?: Record<string, string>, -) { +export async function streamText(props: { + messages: Messages; + env: Env; + options?: StreamingOptions; + apiKeys?: Record<string, string>; + providerSettings?: Record<string, IProviderSetting>; +}) { + const { messages, env, options, apiKeys, providerSettings } = props; let currentModel = DEFAULT_MODEL; let currentProvider = DEFAULT_PROVIDER.name; - const MODEL_LIST = await getModelList(apiKeys || {}); + const MODEL_LIST = await getModelList(apiKeys || {}, providerSettings); const processedMessages =
messages.map((message) => { if (message.role === 'user') { const { model, provider, content } = extractPropertiesFromMessage(message); @@ -88,7 +91,7 @@ export async function streamText( const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS; return _streamText({ - model: getModel(currentProvider, currentModel, env, apiKeys) as any, + model: getModel(currentProvider, currentModel, env, apiKeys, providerSettings) as any, system: getSystemPrompt(), maxTokens: dynamicMaxTokens, messages: convertToCoreMessages(processedMessages as any), diff --git a/app/lib/hooks/useSettings.tsx b/app/lib/hooks/useSettings.tsx index 9b63430..531e481 100644 --- a/app/lib/hooks/useSettings.tsx +++ b/app/lib/hooks/useSettings.tsx @@ -1,14 +1,8 @@ import { useStore } from '@nanostores/react'; -import { - isDebugMode, - isLocalModelsEnabled, - LOCAL_PROVIDERS, - providersStore, - type IProviderSetting, -} from '~/lib/stores/settings'; +import { isDebugMode, isLocalModelsEnabled, LOCAL_PROVIDERS, providersStore } from '~/lib/stores/settings'; import { useCallback, useEffect, useState } from 'react'; import Cookies from 'js-cookie'; -import type { ProviderInfo } from '~/utils/types'; +import type { IProviderSetting, ProviderInfo } from '~/types/model'; export function useSettings() { const providers = useStore(providersStore); diff --git a/app/lib/stores/settings.ts b/app/lib/stores/settings.ts index b6dbc06..31564e6 100644 --- a/app/lib/stores/settings.ts +++ b/app/lib/stores/settings.ts @@ -1,7 +1,7 @@ import { atom, map } from 'nanostores'; import { workbenchStore } from './workbench'; -import type { ProviderInfo } from '~/utils/types'; import { PROVIDER_LIST } from '~/utils/constants'; +import type { IProviderConfig } from '~/types/model'; export interface Shortcut { key: string; @@ -17,14 +17,6 @@ export interface Shortcuts { toggleTerminal: Shortcut; } -export interface IProviderSetting { - enabled?: boolean; - baseUrl?: string; 
-} -export type IProviderConfig = ProviderInfo & { - settings: IProviderSetting; -}; - export const URL_CONFIGURABLE_PROVIDERS = ['Ollama', 'LMStudio', 'OpenAILike']; export const LOCAL_PROVIDERS = ['OpenAILike', 'LMStudio', 'Ollama']; diff --git a/app/routes/api.chat.ts b/app/routes/api.chat.ts index 0073274..9edf1af 100644 --- a/app/routes/api.chat.ts +++ b/app/routes/api.chat.ts @@ -3,6 +3,7 @@ import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants'; import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts'; import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text'; import SwitchableStream from '~/lib/.server/llm/switchable-stream'; +import type { IProviderSetting } from '~/types/model'; export async function action(args: ActionFunctionArgs) { return chatAction(args); @@ -38,6 +39,9 @@ async function chatAction({ context, request }: ActionFunctionArgs) { // Parse the cookie's value (returns an object or null if no cookie exists) const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}'); + const providerSettings: Record<string, IProviderSetting> = JSON.parse( + parseCookies(cookieHeader || '').providers || '{}', + ); const stream = new SwitchableStream(); @@ -60,13 +64,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) { messages.push({ role: 'assistant', content }); messages.push({ role: 'user', content: CONTINUE_PROMPT }); - const result = await streamText(messages, context.cloudflare.env, options, apiKeys); + const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings }); return stream.switchSource(result.toAIStream()); }, }; - const result = await streamText(messages, context.cloudflare.env, options, apiKeys); + const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings }); stream.switchSource(result.toAIStream()); diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts
index 0738ae4..cc51116 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -2,7 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import { stripIndents } from '~/utils/stripIndent'; -import type { ProviderInfo } from '~/types/model'; +import type { IProviderSetting, ProviderInfo } from '~/types/model'; const encoder = new TextEncoder(); const decoder = new TextDecoder(); @@ -11,8 +11,28 @@ export async function action(args: ActionFunctionArgs) { return enhancerAction(args); } +function parseCookies(cookieHeader: string) { + const cookies: any = {}; + + // Split the cookie string by semicolons and spaces + const items = cookieHeader.split(';').map((cookie) => cookie.trim()); + + items.forEach((item) => { + const [name, ...rest] = item.split('='); + + if (name && rest) { + // Decode the name and value, and join value parts in case it contains '=' + const decodedName = decodeURIComponent(name.trim()); + const decodedValue = decodeURIComponent(rest.join('=').trim()); + cookies[decodedName] = decodedValue; + } + }); + + return cookies; +} + async function enhancerAction({ context, request }: ActionFunctionArgs) { - const { message, model, provider, apiKeys } = await request.json<{ + const { message, model, provider } = await request.json<{ message: string; model: string; provider: ProviderInfo; @@ -36,9 +56,17 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { }); } + const cookieHeader = request.headers.get('Cookie'); + + // Parse the cookie's value (returns an object or null if no cookie exists) + const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}'); + const providerSettings: Record = JSON.parse( + parseCookies(cookieHeader || '').providers || '{}', + ); + try { - const result = await streamText( - [ + const result = await streamText({ + messages: [ { 
role: 'user', content: @@ -73,10 +101,10 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { `, }, ], - context.cloudflare.env, - undefined, + env: context.cloudflare.env, apiKeys, - ); + providerSettings, + }); const transformStream = new TransformStream({ transform(chunk, controller) { diff --git a/app/types/model.ts b/app/types/model.ts index c6c58d7..3bfbfde 100644 --- a/app/types/model.ts +++ b/app/types/model.ts @@ -3,9 +3,17 @@ import type { ModelInfo } from '~/utils/types'; export type ProviderInfo = { staticModels: ModelInfo[]; name: string; - getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>; + getDynamicModels?: (apiKeys?: Record<string, string>, providerSettings?: IProviderSetting) => Promise<ModelInfo[]>; getApiKeyLink?: string; labelForGetApiKey?: string; icon?: string; - isEnabled?: boolean; +}; + +export interface IProviderSetting { + enabled?: boolean; + baseUrl?: string; +} + +export type IProviderConfig = ProviderInfo & { + settings: IProviderSetting; }; diff --git a/app/utils/constants.ts b/app/utils/constants.ts index eb90f29..ffedee6 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -1,6 +1,6 @@ import Cookies from 'js-cookie'; import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types'; -import type { ProviderInfo } from '~/types/model'; +import type { ProviderInfo, IProviderSetting } from '~/types/model'; export const WORK_DIR_NAME = 'project'; export const WORK_DIR = `/home/${WORK_DIR_NAME}`; @@ -295,13 +295,16 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat( export let MODEL_LIST: ModelInfo[] = [...staticModels]; -export async function getModelList(apiKeys: Record<string, string>) { +export async function getModelList( + apiKeys: Record<string, string>, + providerSettings?: Record<string, IProviderSetting>, +) { MODEL_LIST = [ ...( await Promise.all( PROVIDER_LIST.filter( (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels, - ).map((p) => p.getDynamicModels(apiKeys)), + ).map((p) =>
p.getDynamicModels(apiKeys, providerSettings?.[p.name])), ) ).flat(), ...staticModels, @@ -309,9 +312,9 @@ export async function getModelList(apiKeys: Record) { return MODEL_LIST; } -async function getTogetherModels(apiKeys?: Record): Promise { +async function getTogetherModels(apiKeys?: Record, settings?: IProviderSetting): Promise { try { - const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || ''; + const baseUrl = settings?.baseUrl || import.meta.env.TOGETHER_API_BASE_URL || ''; const provider = 'Together'; if (!baseUrl) { @@ -350,8 +353,8 @@ async function getTogetherModels(apiKeys?: Record): Promise { - const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434'; +const getOllamaBaseUrl = (settings?: IProviderSetting) => { + const defaultBaseUrl = settings?.baseUrl || import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434'; // Check if we're in the browser if (typeof window !== 'undefined') { @@ -365,7 +368,7 @@ const getOllamaBaseUrl = () => { return isDocker ? 
defaultBaseUrl.replace('localhost', 'host.docker.internal') : defaultBaseUrl; }; -async function getOllamaModels(): Promise { +async function getOllamaModels(apiKeys?: Record, settings?: IProviderSetting): Promise { /* * if (typeof window === 'undefined') { * return []; @@ -373,7 +376,7 @@ async function getOllamaModels(): Promise { */ try { - const baseUrl = getOllamaBaseUrl(); + const baseUrl = getOllamaBaseUrl(settings); const response = await fetch(`${baseUrl}/api/tags`); const data = (await response.json()) as OllamaApiResponse; @@ -389,20 +392,21 @@ async function getOllamaModels(): Promise { } } -async function getOpenAILikeModels(): Promise { +async function getOpenAILikeModels( + apiKeys?: Record, + settings?: IProviderSetting, +): Promise { try { - const baseUrl = import.meta.env.OPENAI_LIKE_API_BASE_URL || ''; + const baseUrl = settings?.baseUrl || import.meta.env.OPENAI_LIKE_API_BASE_URL || ''; if (!baseUrl) { return []; } - let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? 
''; + let apiKey = ''; - const apikeys = JSON.parse(Cookies.get('apiKeys') || '{}'); - - if (apikeys && apikeys.OpenAILike) { - apiKey = apikeys.OpenAILike; + if (apiKeys && apiKeys.OpenAILike) { + apiKey = apiKeys.OpenAILike; } const response = await fetch(`${baseUrl}/models`, { @@ -456,13 +460,13 @@ async function getOpenRouterModels(): Promise { })); } -async function getLMStudioModels(): Promise { +async function getLMStudioModels(_apiKeys?: Record, settings?: IProviderSetting): Promise { if (typeof window === 'undefined') { return []; } try { - const baseUrl = import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234'; + const baseUrl = settings?.baseUrl || import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234'; const response = await fetch(`${baseUrl}/v1/models`); const data = (await response.json()) as any; @@ -477,7 +481,7 @@ async function getLMStudioModels(): Promise { } } -async function initializeModelList(): Promise { +async function initializeModelList(providerSettings?: Record): Promise { let apiKeys: Record = {}; try { @@ -498,7 +502,7 @@ async function initializeModelList(): Promise { await Promise.all( PROVIDER_LIST.filter( (p): p is ProviderInfo & { getDynamicModels: () => Promise } => !!p.getDynamicModels, - ).map((p) => p.getDynamicModels(apiKeys)), + ).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name])), ) ).flat(), ...staticModels, diff --git a/app/utils/types.ts b/app/utils/types.ts index 8742891..1fa253f 100644 --- a/app/utils/types.ts +++ b/app/utils/types.ts @@ -26,12 +26,3 @@ export interface ModelInfo { provider: string; maxTokenAllowed: number; } - -export interface ProviderInfo { - staticModels: ModelInfo[]; - name: string; - getDynamicModels?: () => Promise; - getApiKeyLink?: string; - labelForGetApiKey?: string; - icon?: string; -}