diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py
index 95f442067..079b31bae 100644
--- a/backend/apps/litellm/main.py
+++ b/backend/apps/litellm/main.py
@@ -18,7 +18,7 @@ from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
-from config import SRC_LOG_LEVELS, ENV
+from config import SRC_LOG_LEVELS, ENV, MODEL_CONFIG
from constants import MESSAGES
import os
@@ -67,6 +67,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
+app.state.MODEL_CONFIG = MODEL_CONFIG.get("litellm", [])
# Global variable to store the subprocess reference
background_process = None
@@ -238,6 +239,8 @@ async def get_models(user=Depends(get_current_user)):
)
)
+ for model in data["data"]:
+ add_custom_info_to_model(model)
return data
except Exception as e:
@@ -258,6 +261,14 @@ async def get_models(user=Depends(get_current_user)):
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
+ "custom_info": next(
+ (
+ item
+ for item in app.state.MODEL_CONFIG
+ if item["name"] == model["model_name"]
+ ),
+ {},
+ ),
}
for model in app.state.CONFIG["model_list"]
],
@@ -270,6 +281,12 @@ async def get_models(user=Depends(get_current_user)):
}
+def add_custom_info_to_model(model: dict):
+ model["custom_info"] = next(
+ (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {}
+ )
+
+
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 042d0336d..af4dab891 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
UPLOAD_DIR,
+ MODEL_CONFIG,
)
from utils.misc import calculate_sha256
@@ -64,6 +65,7 @@ app.add_middleware(
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+app.state.MODEL_CONFIG = MODEL_CONFIG.get("ollama", [])
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
@@ -158,15 +160,26 @@ async def get_all_models():
models = {
"models": merge_models_lists(
- map(lambda response: response["models"] if response else None, responses)
+ map(
+ lambda response: (response["models"] if response else None),
+ responses,
+ )
)
}
+ for model in models["models"]:
+ add_custom_info_to_model(model)
app.state.MODELS = {model["model"]: model for model in models["models"]}
return models
+def add_custom_info_to_model(model: dict):
+ model["custom_info"] = next(
+ (item for item in app.state.MODEL_CONFIG if item["name"] == model["model"]), {}
+ )
+
+
@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index b5d1e68d6..7bc401788 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -26,6 +26,7 @@ from config import (
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
+ MODEL_CONFIG,
)
from typing import List, Optional
@@ -47,6 +48,7 @@ app.add_middleware(
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+app.state.MODEL_CONFIG = MODEL_CONFIG.get("openai", [])
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -217,10 +219,19 @@ async def get_all_models():
)
}
+ for model in models["data"]:
+ add_custom_info_to_model(model)
+
log.info(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
    return models
+
+
+def add_custom_info_to_model(model: dict):
+ model["custom_info"] = next(
+ (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {}
+ )
@app.get("/models")
diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py
index 0bad55a6a..c1322dce0 100644
--- a/backend/apps/web/routers/configs.py
+++ b/backend/apps/web/routers/configs.py
@@ -35,6 +35,19 @@ class SetDefaultSuggestionsForm(BaseModel):
suggestions: List[PromptSuggestion]
+class ModelConfig(BaseModel):
+ id: str
+ name: str
+ description: str
+ vision_capable: bool
+
+
+class SetModelConfigForm(BaseModel):
+ ollama: List[ModelConfig]
+ litellm: List[ModelConfig]
+ openai: List[ModelConfig]
+
+
############################
# SetDefaultModels
############################
@@ -57,3 +70,14 @@ async def set_global_default_suggestions(
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
+
+
+@router.post("/models", response_model=SetModelConfigForm)
+async def set_global_model_config(
+ request: Request,
+ form_data: SetModelConfigForm,
+ user=Depends(get_admin_user),
+):
+ data = form_data.model_dump()
+ request.app.state.MODEL_CONFIG = data
+ return request.app.state.MODEL_CONFIG
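A hypothetical request against the new admin endpoint; the `/api/v1/configs` prefix is an assumption about where this router is mounted, and the body mirrors `SetModelConfigForm`:

```python
import requests

ADMIN_TOKEN = "replace-with-a-real-admin-token"

# Every key is required by SetModelConfigForm; the entries are invented.
payload = {
    "ollama": [
        {
            "id": "llama3:latest",
            "name": "llama3:latest",
            "description": "General-purpose local model",
            "vision_capable": False,
        }
    ],
    "litellm": [],
    "openai": [],
}

resp = requests.post(
    "http://localhost:8080/api/v1/configs/models",  # mount prefix is assumed
    json=payload,
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},  # admin-only route
)
print(resp.json())  # echoes the stored MODEL_CONFIG
```

As far as this diff shows, the POST only updates the main app's `state.MODEL_CONFIG`; the litellm/ollama/openai sub-apps read their slices from config at import time, so a restart (or propagation logic not shown here) may be needed before new values reach `custom_info`.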
diff --git a/backend/config.py b/backend/config.py
index 5c6247a9f..2d8f601f3 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -424,6 +424,8 @@ WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
+MODEL_CONFIG = CONFIG_DATA.get("models", {"ollama": [], "litellm": [], "openai": []})
+
####################################
# WEBUI_SECRET_KEY
####################################
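`CONFIG_DATA` is presumably the parsed per-install config file (e.g. the data directory's `config.json`); under that assumption, a `models` section that this line would pick up looks like the following (entries invented for illustration):

```python
# Assumed shape of the persisted "models" section; the fallback keeps
# each backend's list empty when the section is absent.
CONFIG_DATA = {
    "models": {
        "ollama": [
            {
                "id": "llama3:latest",
                "name": "llama3:latest",
                "description": "Example entry",
                "vision_capable": False,
            }
        ],
        "litellm": [],
        "openai": [],
    }
}

MODEL_CONFIG = CONFIG_DATA.get("models", {"ollama": [], "litellm": [], "openai": []})
```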
diff --git a/backend/main.py b/backend/main.py
index 139819f7c..c45862de6 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -58,6 +58,7 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
+ MODEL_CONFIG,
)
from constants import ERROR_MESSAGES
diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts
index 643146b73..23e9f62b8 100644
--- a/src/lib/apis/litellm/index.ts
+++ b/src/lib/apis/litellm/index.ts
@@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => {
id: model.id,
name: model.name ?? model.id,
external: true,
- source: 'LiteLLM'
+ source: 'LiteLLM',
+ custom_info: model.custom_info ?? {}
}))
.sort((a, b) => {
return a.name.localeCompare(b.name);
diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts
index 41b6f9b6d..f99a1a92e 100644
--- a/src/lib/apis/openai/index.ts
+++ b/src/lib/apis/openai/index.ts
@@ -163,7 +163,12 @@ export const getOpenAIModels = async (token: string = '') => {
return models
? models
- .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
+ .map((model) => ({
+ id: model.id,
+ name: model.name ?? model.id,
+ external: true,
+ custom_info: model.custom_info ?? {}
+ }))
.sort((a, b) => {
return a.name.localeCompare(b.name);
})
diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte
index f2a8bb19a..04e3fabb2 100644
--- a/src/lib/components/admin/Settings/Users.svelte
+++ b/src/lib/components/admin/Settings/Users.svelte
@@ -125,7 +125,7 @@
{#each $models.filter((model) => model.id) as model}
-	{model.name}
+	{model.custom_info?.displayName ?? model.name}
{/each}
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte
index 6711ea2b5..a43021fb6 100644
--- a/src/lib/components/chat/MessageInput.svelte
+++ b/src/lib/components/chat/MessageInput.svelte
@@ -1,7 +1,7 @@