From 0809eb79b8efa5f48195ca0cd9db8cb054ad1cd6 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 11 Nov 2024 21:18:51 -0800 Subject: [PATCH] refac: openai connections --- backend/open_webui/apps/ollama/main.py | 38 +- backend/open_webui/apps/openai/main.py | 198 +++++++-- backend/open_webui/config.py | 15 +- src/lib/apis/ollama/index.ts | 10 +- src/lib/apis/openai/index.ts | 55 ++- src/lib/components/admin/Settings.svelte | 2 +- .../admin/Settings/Connections.svelte | 385 ++++++------------ .../Connections/OpenAIConnection.svelte | 107 +++++ .../Connections/OpenAIConnectionModal.svelte | 339 +++++++++++++++ .../components/admin/Settings/Database.svelte | 31 -- .../components/admin/Settings/General.svelte | 6 +- .../components/admin/Settings/Users.svelte | 4 +- .../components/common/SensitiveInput.svelte | 3 +- 13 files changed, 814 insertions(+), 379 deletions(-) create mode 100644 src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte create mode 100644 src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 05cd2d223..2038f4302 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -46,7 +46,11 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) app.add_middleware( CORSMiddleware, @@ -90,34 +94,26 @@ async def get_status(): @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + return { + "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + } class OllamaConfigForm(BaseModel): - enable_ollama_api: Optional[bool] = None + ENABLE_OLLAMA_API: Optional[bool] = None + OLLAMA_BASE_URLS: list[str] @app.post("/config/update") async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS - -@app.get("/urls") -async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - - -class UrlUpdateForm(BaseModel): - urls: list[str] - - -@app.post("/urls/update") -async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.config.OLLAMA_BASE_URLS = form_data.urls - - log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} + return { + "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + } async def fetch_url(url): diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index ed6ad5cbb..ef79c205b 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -16,6 +16,7 @@ from open_webui.config import ( MODEL_FILTER_LIST, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, + OPENAI_API_CONFIGS, AppConfig, ) from open_webui.env import ( @@ -43,7 +44,11 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) app.add_middleware( @@ -62,6 +67,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS app.state.MODELS = {} @@ -77,48 +83,58 @@ async def check_url(request: Request, call_next): @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } class OpenAIConfigForm(BaseModel): - enable_openai_api: Optional[bool] = None + ENABLE_OPENAI_API: Optional[bool] = None + OPENAI_API_BASE_URLS: list[str] + OPENAI_API_KEYS: list[str] + OPENAI_API_CONFIGS: dict @app.post("/config/update") async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS -class UrlsUpdateForm(BaseModel): - urls: list[str] + # Check if API KEYS length is same than API URLS length + if len(app.state.config.OPENAI_API_KEYS) != len( + app.state.config.OPENAI_API_BASE_URLS + ): + if len(app.state.config.OPENAI_API_KEYS) > len( + app.state.config.OPENAI_API_BASE_URLS + ): + app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ + : len(app.state.config.OPENAI_API_BASE_URLS) + ] + else: + app.state.config.OPENAI_API_KEYS += [""] * ( + len(app.state.config.OPENAI_API_BASE_URLS) + - len(app.state.config.OPENAI_API_KEYS) + ) + app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS -class KeysUpdateForm(BaseModel): - keys: list[str] + # Remove any extra configs + config_urls = app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + if url not in config_urls: + app.state.config.OPENAI_API_CONFIGS.pop(url, None) - -@app.get("/urls") -async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.post("/urls/update") -async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): - await get_all_models() - app.state.config.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.get("/keys") -async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} - - -@app.post("/keys/update") -async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - app.state.config.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } @app.post("/audio/speech") @@ -190,7 +206,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def fetch_url(url, key): +async def aiohttp_get(url, key): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: headers = {"Authorization": f"Bearer {key}"} @@ -248,12 +264,8 @@ def merge_models_lists(model_lists): return merged_list -def is_openai_api_disabled(): - return not app.state.config.ENABLE_OPENAI_API - - async def get_all_models_raw() -> list: - if is_openai_api_disabled(): + if not app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length @@ -269,12 +281,55 @@ async def get_all_models_raw() -> list: else: app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [ - fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) - ] + tasks = [] + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + if url not in app.state.config.OPENAI_API_CONFIGS: + tasks.append( + aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + ) + else: + api_config = app.state.config.OPENAI_API_CONFIGS[url] + + enabled = api_config.get("enabled", True) + model_ids = api_config.get("model_ids", []) + + if enabled: + if len(model_ids) == 0: + tasks.append( + aiohttp_get( + f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + ) + ) + else: + model_list = { + "object": "list", + "data": [ + { + "id": model_id, + "name": model_id, + "owned_by": "openai", + "openai": {"id": model_id}, + "urlIdx": idx, + } + for model_id in model_ids + ], + } + + tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) responses = await asyncio.gather(*tasks) + + for idx, response in enumerate(responses): + if response: + url = app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = app.state.config.OPENAI_API_CONFIGS[url] + + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + for model in response["data"]: + model["id"] = f"{prefix_id}.{model['id']}" + log.debug(f"get_all_models:responses() {responses}") return responses @@ -290,7 +345,7 @@ async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... async def get_all_models(raw=False) -> dict[str, list] | list: log.info("get_all_models()") - if is_openai_api_disabled(): + if not app.state.config.ENABLE_OPENAI_API: return [] if raw else {"data": []} responses = await get_all_models_raw() @@ -342,7 +397,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us r = None - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) async with aiohttp.ClientSession(timeout=timeout) as session: try: @@ -361,7 +415,8 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us if "api.openai.com" in url: # Filter models according to the specified conditions response_data["data"] = [ - model for model in response_data.get("data", []) + model + for model in response_data.get("data", []) if not any( name in model["id"] for name in [ @@ -381,7 +436,9 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") # Handle aiohttp-specific connection issues, timeout etc. - raise HTTPException(status_code=500, detail="Open WebUI: Server Connection Error") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) except Exception as e: log.exception(f"Unexpected error: {e}") # Generic error handler in case parsing JSON or other steps fail @@ -389,6 +446,49 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us raise HTTPException(status_code=500, detail=error_detail) +class ConnectionVerificationForm(BaseModel): + url: str + key: str + + +@app.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + headers = {} + headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/models", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) @app.post("/chat/completions") @@ -418,6 +518,14 @@ async def generate_chat_completion( model = app.state.MODELS[payload.get("model")] idx = model["urlIdx"] + api_config = app.state.config.OPENAI_API_CONFIGS.get( + app.state.config.OPENAI_API_BASE_URLS[idx], {} + ) + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + if "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index cd1a55e26..7da3b0d6f 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -607,6 +607,12 @@ OLLAMA_BASE_URLS = PersistentConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) +OLLAMA_API_CONFIGS = PersistentConfig( + "OLLAMA_API_CONFIGS", + "ollama.api_configs", + {}, +) + #################################### # OPENAI_API #################################### @@ -647,15 +653,20 @@ OPENAI_API_BASE_URLS = PersistentConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) -OPENAI_API_KEY = "" +OPENAI_API_CONFIGS = PersistentConfig( + "OPENAI_API_CONFIGS", + "openai.api_configs", + {}, +) +# Get the actual OpenAI API key based on the base URL +OPENAI_API_KEY = "" try: OPENAI_API_KEY = OPENAI_API_KEYS.value[ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except Exception: pass - OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index d4e994312..fcb37f7a4 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -32,7 +32,13 @@ export const getOllamaConfig = async (token: string = '') => { return res; }; -export const updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { +type OllamaConfig = { + ENABLE_OLLAMA_API: boolean, + OLLAMA_BASE_URLS: string[], + OLLAMA_API_CONFIGS: object +} + +export const updateOllamaConfig = async (token: string = '', config: OllamaConfig) => { let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { @@ -43,7 +49,7 @@ export const updateOllamaConfig = async (token: string = '', enable_ollama_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_ollama_api: enable_ollama_api + ...config }) }) .then(async (res) => { diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 2bb11d12a..a9f581249 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -32,7 +32,17 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { + +type OpenAIConfig = { + ENABLE_OPENAI_API: boolean; + OPENAI_API_BASE_URLS: string[]; + OPENAI_API_KEYS: string[]; + OPENAI_API_CONFIGS: object; +} + + + +export const updateOpenAIConfig = async (token: string = '', config: OpenAIConfig) => { let error = null; const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { @@ -43,7 +53,7 @@ export const updateOpenAIConfig = async (token: string = '', enable_openai_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_openai_api: enable_openai_api + ...config }) }) .then(async (res) => { @@ -99,6 +109,7 @@ export const getOpenAIUrls = async (token: string = '') => { return res.OPENAI_API_BASE_URLS; }; + export const updateOpenAIUrls = async (token: string = '', urls: string[]) => { let error = null; @@ -231,41 +242,43 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => { return res; }; -export const getOpenAIModelsDirect = async ( - base_url: string = 'https://api.openai.com/v1', - api_key: string = '' +export const verifyOpenAIConnection = async ( + token: string = '', + url: string = 'https://api.openai.com/v1', + key: string = '' ) => { let error = null; - const res = await fetch(`${base_url}/models`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${api_key}` + const res = await fetch( + `${OPENAI_API_BASE_URL}/verify`, + { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json', + + }, + body: JSON.stringify({ + url, + key + }) } - }) + ) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); }) .catch((err) => { - console.log(err); error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; - return null; + return []; }); if (error) { throw error; } - const models = Array.isArray(res) ? res : (res?.data ?? null); - - return models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) - .filter((model) => (base_url.includes('openai') ? model.name.includes('gpt') : true)) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }); + return res; }; export const generateOpenAIChatCompletion = async ( diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 7e1e489b1..20af84cc0 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -39,7 +39,7 @@ }); -
+
- import { models, user } from '$lib/stores'; + import { toast } from 'svelte-sonner'; import { createEventDispatcher, onMount, getContext, tick } from 'svelte'; const dispatch = createEventDispatcher(); - import { - getOllamaConfig, - getOllamaUrls, - getOllamaVersion, - updateOllamaConfig, - updateOllamaUrls - } from '$lib/apis/ollama'; - import { - getOpenAIConfig, - getOpenAIKeys, - getOpenAIModels, - getOpenAIUrls, - updateOpenAIConfig, - updateOpenAIKeys, - updateOpenAIUrls - } from '$lib/apis/openai'; + import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama'; + import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai'; + import { getModels as _getModels } from '$lib/apis'; + + import { models, user } from '$lib/stores'; - import { toast } from 'svelte-sonner'; import Switch from '$lib/components/common/Switch.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; - import { getModels as _getModels } from '$lib/apis'; - import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; - import Cog6 from '$lib/components/icons/Cog6.svelte'; + + import OpenAIConnection from './Connections/OpenAIConnection.svelte'; + import OpenAIConnectionModal from './Connections/OpenAIConnectionModal.svelte'; + import Plus from '$lib/components/icons/Plus.svelte'; const i18n = getContext('i18n'); @@ -38,126 +27,113 @@ // External let OLLAMA_BASE_URLS = ['']; + let OLLAMA_API_CONFIGS = {}; let OPENAI_API_KEYS = ['']; let OPENAI_API_BASE_URLS = ['']; + let OPENAI_API_CONFIGS = {}; + + let ENABLE_OPENAI_API: null | boolean = null; + let ENABLE_OLLAMA_API: null | boolean = null; let pipelineUrls = {}; - - let ENABLE_OPENAI_API = null; - let ENABLE_OLLAMA_API = null; - - const verifyOpenAIHandler = async (idx) => { - OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); - - OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); - OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS); - - const res = await getOpenAIModels(localStorage.token, idx).catch((error) => { - toast.error(error); - return null; - }); - - if (res) { - toast.success($i18n.t('Server connection verified')); - if (res.pipelines) { - pipelineUrls[OPENAI_API_BASE_URLS[idx]] = true; - } - } - - await models.set(await getModels()); - }; - - const verifyOllamaHandler = async (idx) => { - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => - url.replace(/\/$/, '') - ); - - OLLAMA_BASE_URLS = await updateOllamaUrls(localStorage.token, OLLAMA_BASE_URLS); - - const res = await getOllamaVersion(localStorage.token, idx).catch((error) => { - toast.error(error); - return null; - }); - - if (res) { - toast.success($i18n.t('Server connection verified')); - } - - await models.set(await getModels()); - }; + let showAddOpenAIConnectionModal = false; const updateOpenAIHandler = async () => { - OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); + if (ENABLE_OPENAI_API !== null) { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); + // Check if API KEYS length is same than API URLS length + if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) { + // if there are more keys than urls, remove the extra keys + if (OPENAI_API_KEYS.length > OPENAI_API_BASE_URLS.length) { + OPENAI_API_KEYS = OPENAI_API_KEYS.slice(0, OPENAI_API_BASE_URLS.length); + } - // Check if API KEYS length is same than API URLS length - if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) { - // if there are more keys than urls, remove the extra keys - if (OPENAI_API_KEYS.length > OPENAI_API_BASE_URLS.length) { - OPENAI_API_KEYS = OPENAI_API_KEYS.slice(0, OPENAI_API_BASE_URLS.length); - } - - // if there are more urls than keys, add empty keys - if (OPENAI_API_KEYS.length < OPENAI_API_BASE_URLS.length) { - const diff = OPENAI_API_BASE_URLS.length - OPENAI_API_KEYS.length; - for (let i = 0; i < diff; i++) { - OPENAI_API_KEYS.push(''); + // if there are more urls than keys, add empty keys + if (OPENAI_API_KEYS.length < OPENAI_API_BASE_URLS.length) { + const diff = OPENAI_API_BASE_URLS.length - OPENAI_API_KEYS.length; + for (let i = 0; i < diff; i++) { + OPENAI_API_KEYS.push(''); + } } } - } - OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); - OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS); - await models.set(await getModels()); - }; - - const updateOllamaUrlsHandler = async () => { - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => - url.replace(/\/$/, '') - ); - - console.log(OLLAMA_BASE_URLS); - - if (OLLAMA_BASE_URLS.length === 0) { - ENABLE_OLLAMA_API = false; - await updateOllamaConfig(localStorage.token, ENABLE_OLLAMA_API); - - toast.info($i18n.t('Ollama API disabled')); - } else { - OLLAMA_BASE_URLS = await updateOllamaUrls(localStorage.token, OLLAMA_BASE_URLS); - - const ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => { + const res = await updateOpenAIConfig(localStorage.token, { + ENABLE_OPENAI_API: ENABLE_OPENAI_API, + OPENAI_API_BASE_URLS: OPENAI_API_BASE_URLS, + OPENAI_API_KEYS: OPENAI_API_KEYS, + OPENAI_API_CONFIGS: OPENAI_API_CONFIGS + }).catch((error) => { toast.error(error); - return null; }); - if (ollamaVersion) { - toast.success($i18n.t('Server connection verified')); + if (res) { + toast.success($i18n.t('OpenAI API settings updated')); await models.set(await getModels()); } } }; + const updateOllamaHandler = async () => { + if (ENABLE_OLLAMA_API !== null) { + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => + url.replace(/\/$/, '') + ); + + console.log(OLLAMA_BASE_URLS); + + if (OLLAMA_BASE_URLS.length === 0) { + ENABLE_OLLAMA_API = false; + toast.info($i18n.t('Ollama API disabled')); + } + + const res = await updateOllamaConfig(localStorage.token, { + ENABLE_OLLAMA_API: ENABLE_OLLAMA_API, + OLLAMA_BASE_URLS: OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS: OLLAMA_API_CONFIGS + }).catch((error) => { + toast.error(error); + }); + + if (res) { + toast.success($i18n.t('Ollama API settings updated')); + await models.set(await getModels()); + } + } + }; + + const addOpenAIConnectionHandler = async (connection) => { + OPENAI_API_BASE_URLS = [...OPENAI_API_BASE_URLS, connection.url]; + OPENAI_API_KEYS = [...OPENAI_API_KEYS, connection.key]; + OPENAI_API_CONFIGS[connection.url] = connection.config; + + await updateOpenAIHandler(); + }; + onMount(async () => { if ($user.role === 'admin') { + let ollamaConfig = {}; + let openaiConfig = {}; + await Promise.all([ (async () => { - OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token); + ollamaConfig = await getOllamaConfig(localStorage.token); })(), (async () => { - OPENAI_API_BASE_URLS = await getOpenAIUrls(localStorage.token); - })(), - (async () => { - OPENAI_API_KEYS = await getOpenAIKeys(localStorage.token); + openaiConfig = await getOpenAIConfig(localStorage.token); })() ]); - const ollamaConfig = await getOllamaConfig(localStorage.token); - const openaiConfig = await getOpenAIConfig(localStorage.token); - ENABLE_OPENAI_API = openaiConfig.ENABLE_OPENAI_API; ENABLE_OLLAMA_API = ollamaConfig.ENABLE_OLLAMA_API; + OPENAI_API_BASE_URLS = openaiConfig.OPENAI_API_BASE_URLS; + OPENAI_API_KEYS = openaiConfig.OPENAI_API_KEYS; + OPENAI_API_CONFIGS = openaiConfig.OPENAI_API_CONFIGS; + + OLLAMA_BASE_URLS = ollamaConfig.OLLAMA_BASE_URLS; + OLLAMA_API_CONFIGS = ollamaConfig.OLLAMA_API_CONFIGS; + if (ENABLE_OPENAI_API) { OPENAI_API_BASE_URLS.forEach(async (url, idx) => { const res = await getOpenAIModels(localStorage.token, idx); @@ -165,16 +141,35 @@ pipelineUrls[url] = true; } }); + + for (const url of OPENAI_API_BASE_URLS) { + if (!OPENAI_API_CONFIGS[url]) { + OPENAI_API_CONFIGS[url] = {}; + } + } + } + + if (ENABLE_OLLAMA_API) { + for (const url of OLLAMA_BASE_URLS) { + if (!OLLAMA_API_CONFIGS[url]) { + OLLAMA_API_CONFIGS[url] = {}; + } + } } } }); + +
{ updateOpenAIHandler(); - updateOllamaUrlsHandler(); + updateOllamaHandler(); dispatch('save'); }} @@ -191,7 +186,7 @@ { - updateOpenAIConfig(localStorage.token, ENABLE_OPENAI_API); + updateOpenAIHandler(); }} />
@@ -202,149 +197,39 @@
-
+
{$i18n.t('Manage OpenAI API Connections')}
- + + +
-
+
{#each OPENAI_API_BASE_URLS as url, idx} -
- -
-
- - - {#if pipelineUrls[url]} -
- - - - - - - -
- {/if} -
- - -
-
- -
- - - - - - - - - -
-
+ { + updateOpenAIHandler(); + }} + onDelete={() => { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( + (url, urlIdx) => idx !== urlIdx + ); + OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); + }} + /> {/each}
@@ -435,7 +320,7 @@ + +
+
diff --git a/src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte b/src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte new file mode 100644 index 000000000..4c24898b9 --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte @@ -0,0 +1,339 @@ + + + +
+
+
+ {#if edit} + {$i18n.t('Edit Connection')} + {:else} + {$i18n.t('Add Connection')} + {/if} +
+ +
+ +
+
+ { + e.preventDefault(); + submitHandler(); + }} + > +
+
+
+
{$i18n.t('URL')}
+ +
+ +
+
+ + + + + +
+ + + +
+
+ +
+
+
{$i18n.t('Key')}
+ +
+ +
+
+ +
+
{$i18n.t('Prefix ID')}
+ +
+ + + +
+
+
+ +
+ +
+
+
{$i18n.t('Model IDs')}
+
+ + {#if modelIds.length > 0} +
+ {#each modelIds as modelId, modelIdx} +
+
+ {modelId} +
+
+ +
+
+ {/each} +
+ {:else} +
+ {$i18n.t('Leave empty to include all models from "{{URL}}/models" endpoint', { + URL: url + })} +
+ {/if} +
+ +
+ +
+ + +
+ +
+
+
+ +
+ {#if edit} + + {/if} + + +
+ +
+
+
+
diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte index 3c376cb80..71fda7642 100644 --- a/src/lib/components/admin/Settings/Database.svelte +++ b/src/lib/components/admin/Settings/Database.svelte @@ -181,37 +181,6 @@
{/if} - -
- -
diff --git a/src/lib/components/admin/Settings/General.svelte b/src/lib/components/admin/Settings/General.svelte index b5d4e01e8..02399929d 100644 --- a/src/lib/components/admin/Settings/General.svelte +++ b/src/lib/components/admin/Settings/General.svelte @@ -69,7 +69,7 @@ -
+
@@ -91,7 +91,7 @@
-
+
@@ -115,7 +115,7 @@
-
+
diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index f0b245888..532553ea6 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -133,7 +133,7 @@
-
+
@@ -323,7 +323,7 @@
-
+
diff --git a/src/lib/components/common/SensitiveInput.svelte b/src/lib/components/common/SensitiveInput.svelte index 2f62c6635..cb6484a0b 100644 --- a/src/lib/components/common/SensitiveInput.svelte +++ b/src/lib/components/common/SensitiveInput.svelte @@ -4,7 +4,8 @@ export let required = true; export let readOnly = false; export let outerClassName = 'flex flex-1 bg-transparent'; - export let inputClassName = 'w-full text-sm py-0.5 bg-transparent outline-none'; + export let inputClassName = + 'w-full text-sm py-0.5 placeholder:text-gray-300 dark:placeholder:text-gray-700 bg-transparent outline-none'; export let showButtonClassName = 'pl-1.5 transition bg-transparent'; let show = false;