Fix implicit None returns in get_model_by_id, get_knowledge_by_id, get_tools_by_id, and get_prompt_by_command. Now properly returns 401 for access denied and 404 for not found instead of silently returning None.
436 lines
13 KiB
Python
436 lines
13 KiB
Python
from typing import Optional
|
|
import io
|
|
import base64
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
|
|
from open_webui.models.groups import Groups
|
|
from open_webui.models.models import (
|
|
ModelForm,
|
|
ModelModel,
|
|
ModelResponse,
|
|
ModelListResponse,
|
|
Models,
|
|
)
|
|
|
|
from pydantic import BaseModel
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
Request,
|
|
status,
|
|
Response,
|
|
)
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.access_control import has_access, has_permission
|
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
|
from open_webui.internal.db import get_session
|
|
from sqlalchemy.orm import Session
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def is_valid_model_id(model_id: str) -> bool:
|
|
return model_id and len(model_id) <= 256
|
|
|
|
|
|
###########################
|
|
# GetModels
|
|
###########################
|
|
|
|
|
|
PAGE_ITEM_COUNT = 30
|
|
|
|
|
|
@router.get(
|
|
"/list", response_model=ModelListResponse
|
|
) # do NOT use "/" as path, conflicts with main.py
|
|
async def get_models(
|
|
query: Optional[str] = None,
|
|
view_option: Optional[str] = None,
|
|
tag: Optional[str] = None,
|
|
order_by: Optional[str] = None,
|
|
direction: Optional[str] = None,
|
|
page: Optional[int] = 1,
|
|
user=Depends(get_verified_user),
|
|
db: Session = Depends(get_session),
|
|
):
|
|
|
|
limit = PAGE_ITEM_COUNT
|
|
|
|
page = max(1, page)
|
|
skip = (page - 1) * limit
|
|
|
|
filter = {}
|
|
if query:
|
|
filter["query"] = query
|
|
if view_option:
|
|
filter["view_option"] = view_option
|
|
if tag:
|
|
filter["tag"] = tag
|
|
if order_by:
|
|
filter["order_by"] = order_by
|
|
if direction:
|
|
filter["direction"] = direction
|
|
|
|
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
|
|
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
|
if groups:
|
|
filter["group_ids"] = [group.id for group in groups]
|
|
|
|
filter["user_id"] = user.id
|
|
|
|
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db)
|
|
|
|
|
|
###########################
|
|
# GetBaseModels
|
|
###########################
|
|
|
|
|
|
@router.get("/base", response_model=list[ModelResponse])
|
|
async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
|
return Models.get_base_models(db=db)
|
|
|
|
|
|
###########################
|
|
# GetModelTags
|
|
###########################
|
|
|
|
|
|
@router.get("/tags", response_model=list[str])
|
|
async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
|
models = Models.get_models(db=db)
|
|
else:
|
|
models = Models.get_models_by_user_id(user.id, db=db)
|
|
|
|
tags_set = set()
|
|
for model in models:
|
|
if model.meta:
|
|
meta = model.meta.model_dump()
|
|
for tag in meta.get("tags", []):
|
|
tags_set.add((tag.get("name")))
|
|
|
|
tags = [tag for tag in tags_set]
|
|
tags.sort()
|
|
return tags
|
|
|
|
|
|
############################
|
|
# CreateNewModel
|
|
############################
|
|
|
|
|
|
@router.post("/create", response_model=Optional[ModelModel])
|
|
async def create_new_model(
|
|
request: Request,
|
|
form_data: ModelForm,
|
|
user=Depends(get_verified_user),
|
|
db: Session = Depends(get_session),
|
|
):
|
|
if user.role != "admin" and not has_permission(
|
|
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
|
|
):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
|
)
|
|
|
|
model = Models.get_model_by_id(form_data.id, db=db)
|
|
if model:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
|
)
|
|
|
|
if not is_valid_model_id(form_data.id):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
|
|
)
|
|
|
|
else:
|
|
model = Models.insert_new_model(form_data, user.id, db=db)
|
|
if model:
|
|
return model
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.DEFAULT(),
|
|
)
|
|
|
|
|
|
############################
|
|
# ExportModels
|
|
############################
|
|
|
|
|
|
@router.get("/export", response_model=list[ModelModel])
|
|
async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
if user.role != "admin" and not has_permission(
|
|
user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db
|
|
):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
|
)
|
|
|
|
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
|
return Models.get_models(db=db)
|
|
else:
|
|
return Models.get_models_by_user_id(user.id, db=db)
|
|
|
|
|
|
############################
|
|
# ImportModels
|
|
############################
|
|
|
|
|
|
class ModelsImportForm(BaseModel):
|
|
models: list[dict]
|
|
|
|
|
|
@router.post("/import", response_model=bool)
|
|
async def import_models(
|
|
request: Request,
|
|
user=Depends(get_verified_user),
|
|
form_data: ModelsImportForm = (...),
|
|
db: Session = Depends(get_session),
|
|
):
|
|
if user.role != "admin" and not has_permission(
|
|
user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db
|
|
):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
|
)
|
|
try:
|
|
data = form_data.models
|
|
if isinstance(data, list):
|
|
for model_data in data:
|
|
# Here, you can add logic to validate model_data if needed
|
|
model_id = model_data.get("id")
|
|
|
|
if model_id and is_valid_model_id(model_id):
|
|
existing_model = Models.get_model_by_id(model_id, db=db)
|
|
if existing_model:
|
|
# Update existing model
|
|
model_data["meta"] = model_data.get("meta", {})
|
|
model_data["params"] = model_data.get("params", {})
|
|
|
|
updated_model = ModelForm(
|
|
**{**existing_model.model_dump(), **model_data}
|
|
)
|
|
Models.update_model_by_id(model_id, updated_model, db=db)
|
|
else:
|
|
# Insert new model
|
|
model_data["meta"] = model_data.get("meta", {})
|
|
model_data["params"] = model_data.get("params", {})
|
|
new_model = ModelForm(**model_data)
|
|
Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)
|
|
return True
|
|
else:
|
|
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
|
except Exception as e:
|
|
log.exception(e)
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
############################
|
|
# SyncModels
|
|
############################
|
|
|
|
|
|
class SyncModelsForm(BaseModel):
|
|
models: list[ModelModel] = []
|
|
|
|
|
|
@router.post("/sync", response_model=list[ModelModel])
|
|
async def sync_models(
|
|
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session)
|
|
):
|
|
return Models.sync_models(user.id, form_data.models, db=db)
|
|
|
|
|
|
###########################
|
|
# GetModelById
|
|
###########################
|
|
|
|
|
|
class ModelIdForm(BaseModel):
|
|
id: str
|
|
|
|
|
|
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
|
|
@router.get("/model", response_model=Optional[ModelResponse])
|
|
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
model = Models.get_model_by_id(id, db=db)
|
|
if model:
|
|
if (
|
|
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
|
or model.user_id == user.id
|
|
or has_access(user.id, "read", model.access_control, db=db)
|
|
):
|
|
return model
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
|
)
|
|
|
|
|
|
###########################
|
|
# GetModelById
|
|
###########################
|
|
|
|
|
|
@router.get("/model/profile/image")
|
|
async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
model = Models.get_model_by_id(id, db=db)
|
|
# Cache-control headers to prevent stale cached images
|
|
cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
|
|
|
|
if model:
|
|
if model.meta.profile_image_url:
|
|
if model.meta.profile_image_url.startswith("http"):
|
|
return Response(
|
|
status_code=status.HTTP_302_FOUND,
|
|
headers={"Location": model.meta.profile_image_url, **cache_headers},
|
|
)
|
|
elif model.meta.profile_image_url.startswith("data:image"):
|
|
try:
|
|
header, base64_data = model.meta.profile_image_url.split(",", 1)
|
|
image_data = base64.b64decode(base64_data)
|
|
image_buffer = io.BytesIO(image_data)
|
|
media_type = header.split(";")[0].lstrip("data:")
|
|
|
|
return StreamingResponse(
|
|
image_buffer,
|
|
media_type=media_type,
|
|
headers={
|
|
"Content-Disposition": "inline",
|
|
**cache_headers,
|
|
},
|
|
)
|
|
except Exception as e:
|
|
pass
|
|
|
|
return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers)
|
|
else:
|
|
return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers)
|
|
|
|
|
|
############################
|
|
# ToggleModelById
|
|
############################
|
|
|
|
|
|
@router.post("/model/toggle", response_model=Optional[ModelResponse])
|
|
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
model = Models.get_model_by_id(id, db=db)
|
|
if model:
|
|
if (
|
|
user.role == "admin"
|
|
or model.user_id == user.id
|
|
or has_access(user.id, "write", model.access_control, db=db)
|
|
):
|
|
model = Models.toggle_model_by_id(id, db=db)
|
|
|
|
if model:
|
|
return model
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
|
)
|
|
|
|
|
|
############################
|
|
# UpdateModelById
|
|
############################
|
|
|
|
|
|
@router.post("/model/update", response_model=Optional[ModelModel])
|
|
async def update_model_by_id(
|
|
form_data: ModelForm,
|
|
user=Depends(get_verified_user),
|
|
db: Session = Depends(get_session),
|
|
):
|
|
model = Models.get_model_by_id(form_data.id, db=db)
|
|
if not model:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
|
)
|
|
|
|
if (
|
|
model.user_id != user.id
|
|
and not has_access(user.id, "write", model.access_control, db=db)
|
|
and user.role != "admin"
|
|
):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
)
|
|
|
|
model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
|
|
return model
|
|
|
|
|
|
############################
|
|
# DeleteModelById
|
|
############################
|
|
|
|
|
|
@router.post("/model/delete", response_model=bool)
|
|
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)):
|
|
model = Models.get_model_by_id(form_data.id, db=db)
|
|
if not model:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
|
)
|
|
|
|
if (
|
|
user.role != "admin"
|
|
and model.user_id != user.id
|
|
and not has_access(user.id, "write", model.access_control, db=db)
|
|
):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
|
)
|
|
|
|
result = Models.delete_model_by_id(form_data.id, db=db)
|
|
return result
|
|
|
|
|
|
@router.delete("/delete/all", response_model=bool)
|
|
async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)):
|
|
result = Models.delete_all_models(db=db)
|
|
return result
|