diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 95f442067..079b31bae 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, ConfigDict from typing import Optional, List from utils.utils import get_verified_user, get_current_user, get_admin_user -from config import SRC_LOG_LEVELS, ENV +from config import SRC_LOG_LEVELS, ENV, MODEL_CONFIG from constants import MESSAGES import os @@ -67,6 +67,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file: app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config +app.state.MODEL_CONFIG = MODEL_CONFIG.get("litellm", []) # Global variable to store the subprocess reference background_process = None @@ -238,6 +239,8 @@ async def get_models(user=Depends(get_current_user)): ) ) + for model in data["data"]: + add_custom_info_to_model(model) return data except Exception as e: @@ -258,6 +261,14 @@ async def get_models(user=Depends(get_current_user)): "object": "model", "created": int(time.time()), "owned_by": "openai", + "custom_info": next( + ( + item + for item in app.state.MODEL_CONFIG + if item["name"] == model["model_name"] + ), + {}, + ), } for model in app.state.CONFIG["model_list"] ], @@ -270,6 +281,12 @@ async def get_models(user=Depends(get_current_user)): } +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {} + ) + + @app.get("/model/info") async def get_model_list(user=Depends(get_admin_user)): return {"data": app.state.CONFIG["model_list"]} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 042d0336d..af4dab891 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, + MODEL_CONFIG, ) from utils.misc import calculate_sha256 @@ -64,6 +65,7 @@ app.add_middleware( app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = MODEL_CONFIG.get("ollama", []) app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -158,15 +160,26 @@ async def get_all_models(): models = { "models": merge_models_lists( - map(lambda response: response["models"] if response else None, responses) + map( + lambda response: (response["models"] if response else None), + responses, + ) ) } + for model in models["models"]: + add_custom_info_to_model(model) app.state.MODELS = {model["model"]: model for model in models["models"]} return models +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item["name"] == model["model"]), {} + ) + + @app.get("/api/tags") @app.get("/api/tags/{url_idx}") async def get_ollama_tags( diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index b5d1e68d6..7bc401788 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -26,6 +26,7 @@ from config import ( CACHE_DIR, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, + MODEL_CONFIG, ) from typing import List, Optional @@ -47,6 +48,7 @@ app.add_middleware( app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = MODEL_CONFIG.get("openai", []) app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -217,10 +219,19 @@ async def get_all_models(): ) } + for model in models["data"]: + add_custom_info_to_model(model) + log.info(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} - return models + return models + + +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item["name"] == model["id"]), {} + ) @app.get("/models") diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 0bad55a6a..c1322dce0 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -35,6 +35,19 @@ class SetDefaultSuggestionsForm(BaseModel): suggestions: List[PromptSuggestion] +class ModelConfig(BaseModel): + id: str + name: str + description: str + vision_capable: bool + + +class SetModelConfigForm(BaseModel): + ollama: List[ModelConfig] + litellm: List[ModelConfig] + openai: List[ModelConfig] + + ############################ # SetDefaultModels ############################ @@ -57,3 +70,14 @@ async def set_global_default_suggestions( data = form_data.model_dump() request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] return request.app.state.DEFAULT_PROMPT_SUGGESTIONS + + +@router.post("/models", response_model=SetModelConfigForm) +async def set_global_default_suggestions( + request: Request, + form_data: SetModelConfigForm, + user=Depends(get_admin_user), +): + data = form_data.model_dump() + request.app.state.MODEL_CONFIG = data + return request.app.state.MODEL_CONFIG diff --git a/backend/config.py b/backend/config.py index 5c6247a9f..2d8f601f3 100644 --- a/backend/config.py +++ b/backend/config.py @@ -424,6 +424,8 @@ WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" +MODEL_CONFIG = CONFIG_DATA.get("models", {"ollama": [], "litellm": [], "openai": []}) + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/main.py b/backend/main.py index 139819f7c..c45862de6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,6 +58,7 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, + MODEL_CONFIG, ) from constants import ERROR_MESSAGES diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts index 643146b73..23e9f62b8 100644 --- a/src/lib/apis/litellm/index.ts +++ b/src/lib/apis/litellm/index.ts @@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => { id: model.id, name: model.name ?? model.id, external: true, - source: 'LiteLLM' + source: 'LiteLLM', + custom_info: model.custom_info ?? {} })) .sort((a, b) => { return a.name.localeCompare(b.name); diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 41b6f9b6d..f99a1a92e 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -163,7 +163,12 @@ export const getOpenAIModels = async (token: string = '') => { return models ? models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) + .map((model) => ({ + id: model.id, + name: model.name ?? model.id, + external: true, + custom_info: model.custom_info ?? {} + })) .sort((a, b) => { return a.name.localeCompare(b.name); }) diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index f2a8bb19a..04e3fabb2 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -125,7 +125,7 @@ {#each $models.filter((model) => model.id) as model} {model.custom_info?.displayName ?? model.name} {/each} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 6711ea2b5..a43021fb6 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,7 +1,7 @@