diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 079b31bae..8161b15e5 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -283,7 +283,7 @@ async def get_models(user=Depends(get_current_user)): def add_custom_info_to_model(model: dict): model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {} + (item for item in app.state.MODEL_CONFIG if item["id"] == model["id"]), {} ) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index af4dab891..ce83b249b 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -176,7 +176,7 @@ async def get_all_models(): def add_custom_info_to_model(model: dict): model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item["name"] == model["model"]), {} + (item for item in app.state.MODEL_CONFIG if item["id"] == model["model"]), {} ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 7bc401788..2009dabbc 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -230,7 +230,7 @@ async def get_all_models(): def add_custom_info_to_model(model: dict): model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {} + (item for item in app.state.MODEL_CONFIG if item["id"] == model["id"]), {} ) diff --git a/backend/main.py b/backend/main.py index 33b4f5f58..50dcd3274 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,6 +58,7 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, + MODEL_CONFIG, ) from constants import ERROR_MESSAGES @@ -97,6 +98,8 @@ app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = MODEL_CONFIG + app.state.WEBHOOK_URL = WEBHOOK_URL origins = ["*"] @@ -311,12 +314,19 @@ async def update_model_config( litellm_app.state.MODEL_CONFIG = data.get("litellm", []) - return { + app.state.MODEL_CONFIG = { "ollama": ollama_app.state.MODEL_CONFIG, "openai": openai_app.state.MODEL_CONFIG, "litellm": litellm_app.state.MODEL_CONFIG, } + return {"models": app.state.MODEL_CONFIG} + + +@app.get("/api/config/models") +async def get_model_config(user=Depends(get_admin_user)): + return {"models": app.state.MODEL_CONFIG} + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a610f7210..9d0db99c0 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -196,3 +196,71 @@ export const updateWebhookUrl = async (token: string, url: string) => { return res.url; }; + +export const getModelConfig = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res.models; +}; + +export interface ModelConfig { + id: string; + name?: string; + description?: string; + vision_capable?: boolean; +} + +export interface GlobalModelConfig { + ollama: ModelConfig[]; + litellm: ModelConfig[]; + openai: ModelConfig[]; +} + +export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify(config) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index 04e3fabb2..6f4144634 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -125,7 +125,7 @@ {#each $models.filter((model) => model.id) as model} {model.custom_info?.name ?? model.name} {/each} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 1a19e4bb4..70fd68494 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -522,7 +522,7 @@ />
Talking to {selectedModel.custom_info?.displayName ?? selectedModel.name} + >{selectedModel.custom_info?.name ?? selectedModel.name}
diff --git a/src/lib/components/chat/MessageInput/Models.svelte b/src/lib/components/chat/MessageInput/Models.svelte index ed8da4e78..f7b7dfb1b 100644 --- a/src/lib/components/chat/MessageInput/Models.svelte +++ b/src/lib/components/chat/MessageInput/Models.svelte @@ -22,10 +22,10 @@ $: filteredModels = $models .filter((p) => - (p.custom_info?.displayName ?? p.name).includes(prompt.split(' ')?.at(0)?.substring(1) ?? '') + (p.custom_info?.name ?? p.name).includes(prompt.split(' ')?.at(0)?.substring(1) ?? '') ) .sort((a, b) => - (a.custom_info?.displayName ?? a.name).localeCompare(b.custom_info?.displayName ?? b.name) + (a.custom_info?.name ?? a.name).localeCompare(b.custom_info?.name ?? b.name) ); $: if (prompt) { @@ -160,7 +160,7 @@ on:focus={() => {}} >
- {model.custom_info?.displayName ?? model.name} + {model.custom_info?.name ?? model.name}
@@ -1129,6 +1204,148 @@ {/if}
+
+
+ +
+
+
+
+
+
{$i18n.t('Manage Model Information')}
+ +
+
+ + {#if showModelInfo} +
+
+
{$i18n.t('Current Models')}
+
+ +
+
+ +
+ +
+ + {#if selectedModelId} +
+
{$i18n.t('Model Display Name')}
+
+
+ +
+ + +
+
+ +
+
{$i18n.t('Model Description')}
+ +
+
+