refac: ollama ps

This commit is contained in:
Timothy Jaeryang Baek 2025-05-23 18:47:50 +04:00
parent 65d997a6c5
commit a50a8e2ef9

View File

@ -300,6 +300,22 @@ async def update_config(
}
def merge_ollama_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
@ -364,23 +380,8 @@ async def get_all_models(request: Request, user: UserModel = None):
if connection_type:
model["connection_type"] = connection_type
def merge_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
models = {
"models": merge_models_lists(
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
@ -468,6 +469,72 @@ async def get_ollama_tags(
return models
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(f"{url}/api/ps", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
models = {
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
models["models"] = await get_filtered_models(models, user)
else:
models = {"models": []}
return models
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
@ -541,32 +608,6 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
return {"version": False}
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [
send_get_request(
f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
class ModelNameForm(BaseModel):
name: str