Parallelize base model fetching

This commit is contained in:
toriset 2025-05-27 15:35:16 +03:00 committed by GitHub
parent 40bea00e3d
commit 27de981246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
import time
import logging
import asyncio
import sys
from aiocache import cached
@ -33,18 +34,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request, user: UserModel = None):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request, user=user)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request, user=user)
ollama_models = [
async def fetch_ollama_models(request, user):
raw_ollama_models = await ollama.get_all_models(request, user=user)
return [
{
"id": model["model"],
"name": model["name"],
@ -55,13 +47,33 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"connection_type": model.get("connection_type", "local"),
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
for model in raw_ollama_models["models"]
]
function_models = await get_function_models(request)
models = function_models + openai_models + ollama_models
return models
async def fetch_openai_models(request, user):
openai_response = await openai.get_all_models(request, user=user)
return openai_response["data"]
async def get_all_base_models(request: Request, user: UserModel = None):
openai_task = (
fetch_openai_models(request, user)
if request.app.state.config.ENABLE_OPENAI_API
else asyncio.sleep(0, result=[])
)
ollama_task = (
fetch_ollama_models(request, user)
if request.app.state.config.ENABLE_OLLAMA_API
else asyncio.sleep(0, result=[])
)
function_task = get_function_models(request)
openai_models, ollama_models, function_models = await asyncio.gather(
openai_task, ollama_task, function_task
)
return function_models + openai_models + ollama_models
async def get_all_models(request, user: UserModel = None):