mirror of
https://github.com/open-webui/open-webui
synced 2025-06-14 18:33:15 +00:00
refac: ollama ps
This commit is contained in:
parent
65d997a6c5
commit
a50a8e2ef9
@ -300,6 +300,22 @@ async def update_config(
|
||||
}
|
||||
|
||||
|
||||
def merge_ollama_models_lists(model_lists):
    """
    Merge several model lists into one list with a single entry per model.

    Args:
        model_lists: An iterable of model lists. The position of each list in
            the iterable is recorded as its source index. An entry may be
            ``None``, in which case it is skipped but still consumes its index.
            Every model is a dict with a ``"model"`` key, which is used as the
            model's identity.

    Returns:
        A list of model dicts in first-seen order. Each dict carries a
        ``"urls"`` list holding the indices of every input list that contained
        that model (each index at most once).

    Note:
        The first dict seen for a given identifier is annotated in place with
        ``"urls"`` and returned as-is (not copied); later duplicates only
        contribute their index.
    """
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is None:
            continue

        for model in model_list:
            # Named `model_id` rather than `id` to avoid shadowing the builtin.
            model_id = model["model"]
            existing = merged_models.get(model_id)

            if existing is None:
                model["urls"] = [idx]
                merged_models[model_id] = model
            elif idx not in existing["urls"]:
                # A list that repeats a model must not register its index twice.
                existing["urls"].append(idx)

    return list(merged_models.values())
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
@ -364,23 +380,8 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
def merge_models_lists(model_lists):
    """
    Merge several model lists into one list with a single entry per model.

    Args:
        model_lists: An iterable of model lists. The position of each list in
            the iterable is recorded as its source index. An entry may be
            ``None``, in which case it is skipped but still consumes its index.
            Every model is a dict with a ``"model"`` key, which is used as the
            model's identity.

    Returns:
        A list of model dicts in first-seen order. Each dict carries a
        ``"urls"`` list holding the indices of every input list that contained
        that model (each index at most once).

    Note:
        The first dict seen for a given identifier is annotated in place with
        ``"urls"`` and returned as-is (not copied); later duplicates only
        contribute their index.
    """
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is None:
            continue

        for model in model_list:
            # Named `model_id` rather than `id` to avoid shadowing the builtin.
            model_id = model["model"]
            existing = merged_models.get(model_id)

            if existing is None:
                model["urls"] = [idx]
                merged_models[model_id] = model
            elif idx not in existing["urls"]:
                # A list that repeats a model must not register its index twice.
                existing["urls"].append(idx)

    return list(merged_models.values())
|
||||
|
||||
models = {
|
||||
"models": merge_models_lists(
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
@ -468,6 +469,72 @@ async def get_ollama_tags(
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.

    Queries ``/api/ps`` on every configured base URL whose connection is not
    disabled, applies the per-connection ``prefix_id`` to the reported model
    names, and merges the results into ``{"models": [...]}``. When the Ollama
    API is disabled an empty model list is returned.
    """
    config = request.app.state.config

    if not config.ENABLE_OLLAMA_API:
        return {"models": []}

    # Fan out one request per base URL. Disabled connections get a no-op
    # placeholder so that response positions still line up with URL indices.
    api_configs = config.OLLAMA_API_CONFIGS
    pending = []
    for idx, url in enumerate(config.OLLAMA_BASE_URLS):
        # Configs are keyed by index; keying by URL is kept for legacy support.
        is_configured = str(idx) in api_configs or url in api_configs
        if not is_configured:
            pending.append(send_get_request(f"{url}/api/ps", user=user))
            continue

        api_config = api_configs.get(
            str(idx),
            api_configs.get(url, {}),  # Legacy support
        )

        if api_config.get("enable", True):
            pending.append(
                send_get_request(
                    f"{url}/api/ps", api_config.get("key", None), user=user
                )
            )
        else:
            pending.append(asyncio.ensure_future(asyncio.sleep(0, None)))

    responses = await asyncio.gather(*pending)

    # Apply each connection's prefix to the model names it reported.
    for idx, response in enumerate(responses):
        if not response:
            continue

        url = config.OLLAMA_BASE_URLS[idx]
        api_config = config.OLLAMA_API_CONFIGS.get(
            str(idx),
            config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
        )
        prefix_id = api_config.get("prefix_id", None)

        for model in response.get("models", []):
            if prefix_id:
                model["model"] = f"{prefix_id}.{model['model']}"

    per_node_models = (
        response.get("models", []) if response else None for response in responses
    )
    models = {"models": merge_ollama_models_lists(per_node_models)}

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models["models"] = await get_filtered_models(models, user)

    return models
|
||||
|
||||
|
||||
@router.get("/api/version")
|
||||
@router.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
@ -541,32 +608,6 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
return {"version": False}
|
||||
|
||||
|
||||
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.

    Returns a dict mapping each configured base URL to the response of its
    ``/api/ps`` request, or an empty dict when the Ollama API is disabled.
    """
    config = request.app.state.config

    if not config.ENABLE_OLLAMA_API:
        return {}

    request_tasks = []
    for idx, url in enumerate(config.OLLAMA_BASE_URLS):
        # Configs are keyed by index; keying by URL is kept for legacy support.
        legacy_config = config.OLLAMA_API_CONFIGS.get(url, {})
        api_config = config.OLLAMA_API_CONFIGS.get(str(idx), legacy_config)

        request_tasks.append(
            send_get_request(
                f"{url}/api/ps",
                api_config.get("key", None),
                user=user,
            )
        )

    responses = await asyncio.gather(*request_tasks)

    return dict(zip(config.OLLAMA_BASE_URLS, responses))
|
||||
|
||||
|
||||
class ModelNameForm(BaseModel):
    """Form model with a single required string field, ``name``."""

    name: str
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user