diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index bef91443a..2b771d5c6 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -242,8 +242,6 @@ async def get_models(user=Depends(get_current_user)): ) ) - for model in data["data"]: - add_custom_info_to_model(model) return data except Exception as e: @@ -284,12 +282,6 @@ async def get_models(user=Depends(get_current_user)): } -def add_custom_info_to_model(model: dict): - model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item.id == model["id"]), None - ) - - @app.get("/model/info") async def get_model_list(user=Depends(get_admin_user)): return {"data": app.state.CONFIG["model_list"]} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 178cfa5fd..674ff50c4 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -67,8 +67,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = Models.get_all_models() - app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS @@ -192,21 +190,12 @@ async def get_all_models(): else: models = {"models": []} - - for model in models["models"]: - add_custom_info_to_model(model) app.state.MODELS = {model["model"]: model for model in models["models"]} return models -def add_custom_info_to_model(model: dict): - model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item.id == model["model"]), None - ) - - @app.get("/api/tags") @app.get("/api/tags/{url_idx}") async def get_ollama_tags( diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 0e2f28409..467164d97 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -52,8 +52,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = Models.get_all_models() - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS @@ -207,7 +205,13 @@ def merge_models_lists(model_lists): if models is not None and "error" not in models: merged_list.extend( [ - {**model, "urlIdx": idx} + { + **model, + "name": model["id"], + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } for model in models if "api.openai.com" not in app.state.config.OPENAI_API_BASE_URLS[idx] @@ -250,21 +254,12 @@ async def get_all_models(): ) } - for model in models["data"]: - add_custom_info_to_model(model) - log.info(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} return models -def add_custom_info_to_model(model: dict): - model["custom_info"] = next( - (item for item in app.state.MODEL_CONFIG if item.id == model["id"]), None - ) - - @app.get("/models") @app.get("/models/{url_idx}") async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): diff --git a/backend/main.py b/backend/main.py index e19ab57fa..24442488a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -19,8 +19,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse, Response -from apps.ollama.main import app as ollama_app -from apps.openai.main import app as openai_app +from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models +from apps.openai.main import app as openai_app, get_all_models as get_openai_models from apps.litellm.main import ( app as litellm_app, @@ -39,7 +39,7 @@ from pydantic import BaseModel from typing import List, Optional from apps.web.models.models import Models, ModelModel -from utils.utils import get_admin_user +from utils.utils import get_admin_user, get_verified_user from apps.rag.utils import rag_messages from config import ( @@ -53,6 +53,8 @@ from config import ( FRONTEND_BUILD_DIR, CACHE_DIR, STATIC_DIR, + ENABLE_OPENAI_API, + ENABLE_OLLAMA_API, ENABLE_LITELLM, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, @@ -110,10 +112,13 @@ app = FastAPI( ) app.state.config = AppConfig() + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API + app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST -app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.WEBHOOK_URL = WEBHOOK_URL @@ -249,9 +254,11 @@ async def update_embedding_function(request: Request, call_next): return response +# TODO: Deprecate LiteLLM app.mount("/litellm/api", litellm_app) + app.mount("/ollama", ollama_app) -app.mount("/openai/api", openai_app) +app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) @@ -262,6 +269,72 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION +@app.get("/api/models") +async def get_models(user=Depends(get_verified_user)): + openai_models = [] + ollama_models = [] + + if app.state.config.ENABLE_OPENAI_API: + openai_models = await get_openai_models() + openai_app.state.MODELS = openai_models + + openai_models = openai_models["data"] + + if app.state.config.ENABLE_OLLAMA_API: + ollama_models = await get_ollama_models() + ollama_app.state.MODELS = ollama_models + + print(ollama_models) + + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + print("openai", openai_models) + print("ollama", ollama_models) + + models = openai_models + ollama_models + custom_models = Models.get_all_models() + + for custom_model in custom_models: + if custom_model.base_model_id == None: + for model in models: + if custom_model.id == model["id"]: + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + else: + models.append( + { + "id": custom_model.id, + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": "user", + "info": custom_model.model_dump(), + } + ) + + if app.state.config.ENABLE_MODEL_FILTER: + if user.role == "user": + models = list( + filter( + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, + models, + ) + ) + return {"data": models} + + return {"data": models} + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a7b59a7ca..9d776ff7e 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,5 +1,33 @@ import { WEBUI_BASE_URL } from '$lib/constants'; +export const getModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res?.data ?? []; +}; + export const getBackendConfig = async () => { let error = null; diff --git a/src/lib/components/chat/MessageInput/Models.svelte b/src/lib/components/chat/MessageInput/Models.svelte index 652a73a2e..71096f1c5 100644 --- a/src/lib/components/chat/MessageInput/Models.svelte +++ b/src/lib/components/chat/MessageInput/Models.svelte @@ -21,10 +21,8 @@ let filteredModels = []; $: filteredModels = $models - .filter((p) => - (p.custom_info?.name ?? p.name).includes(prompt.split(' ')?.at(0)?.substring(1) ?? '') - ) - .sort((a, b) => (a.custom_info?.name ?? a.name).localeCompare(b.custom_info?.name ?? b.name)); + .filter((p) => p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) + .sort((a, b) => a.name.localeCompare(b.name)); $: if (prompt) { selectedIdx = 0; @@ -158,7 +156,7 @@ on:focus={() => {}} >