diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index adb63f520..464680124 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -1,5 +1,6 @@ import time import logging +import asyncio import sys from aiocache import cached @@ -33,35 +34,46 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) +async def fetch_ollama_models(request, user): + raw_ollama_models = await ollama.get_all_models(request, user=user) + return [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + "connection_type": model.get("connection_type", "local"), + "tags": model.get("tags", []), + } + for model in raw_ollama_models["models"] + ] + + +async def fetch_openai_models(request, user): + openai_response = await openai.get_all_models(request, user=user) + return openai_response["data"] + + async def get_all_base_models(request: Request, user: UserModel = None): - function_models = [] - openai_models = [] - ollama_models = [] + openai_task = ( + fetch_openai_models(request, user) + if request.app.state.config.ENABLE_OPENAI_API + else asyncio.sleep(0, result=[]) + ) + ollama_task = ( + fetch_ollama_models(request, user) + if request.app.state.config.ENABLE_OLLAMA_API + else asyncio.sleep(0, result=[]) + ) + function_task = get_function_models(request) - if request.app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request, user=user) - openai_models = openai_models["data"] + openai_models, ollama_models, function_models = await asyncio.gather( + openai_task, ollama_task, function_task + ) - if request.app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request, user=user) - ollama_models = [ - { - "id": model["model"], - "name": model["name"], - "object": "model", - "created": int(time.time()), - "owned_by": "ollama", - "ollama": model, - "connection_type": model.get("connection_type", "local"), - "tags": model.get("tags", []), - } - for model in ollama_models["models"] - ] - - function_models = await get_function_models(request) - models = function_models + openai_models + ollama_models - - return models + return function_models + openai_models + ollama_models async def get_all_models(request, user: UserModel = None):