diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 4d31ac2f9..297d9e5a2 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -97,7 +97,9 @@ async def get_models( @router.get("/base", response_model=list[ModelResponse]) -async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def get_base_models( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): return Models.get_base_models(db=db) @@ -107,7 +109,9 @@ async def get_base_models(user=Depends(get_admin_user), db: Session = Depends(ge @router.get("/tags", response_model=list[str]) -async def get_model_tags(user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_model_tags( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: models = Models.get_models(db=db) else: @@ -175,9 +179,16 @@ async def create_new_model( @router.get("/export", response_model=list[ModelModel]) -async def export_models(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def export_models( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS, db=db + user.id, + "workspace.models_export", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -207,7 +218,10 @@ async def import_models( db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS, db=db + user.id, + "workspace.models_import", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -236,7 +250,9 @@ async def import_models( model_data["meta"] = model_data.get("meta", {}) model_data["params"] = model_data.get("params", {}) new_model = ModelForm(**model_data) - Models.insert_new_model(user_id=user.id, form_data=new_model, db=db) + Models.insert_new_model( + user_id=user.id, form_data=new_model, db=db + ) return True else: raise HTTPException(status_code=400, detail="Invalid JSON format") @@ -256,7 +272,10 @@ class SyncModelsForm(BaseModel): @router.post("/sync", response_model=list[ModelModel]) async def sync_models( - request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + form_data: SyncModelsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): return Models.sync_models(user.id, form_data.models, db=db) @@ -272,7 +291,9 @@ class ModelIdForm(BaseModel): # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id @router.get("/model", response_model=Optional[ModelResponse]) -async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_model_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): model = Models.get_model_by_id(id, db=db) if model: if ( @@ -299,17 +320,17 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user), db: Session @router.get("/model/profile/image") -async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_model_profile_image( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): model = Models.get_model_by_id(id, db=db) - # Cache-control headers to prevent stale cached images - cache_headers = {"Cache-Control": "no-cache, must-revalidate"} if model: if model.meta.profile_image_url: if model.meta.profile_image_url.startswith("http"): return Response( status_code=status.HTTP_302_FOUND, - headers={"Location": model.meta.profile_image_url, **cache_headers}, + headers={"Location": model.meta.profile_image_url}, ) elif model.meta.profile_image_url.startswith("data:image"): try: @@ -323,15 +344,14 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: media_type=media_type, headers={ "Content-Disposition": "inline", - **cache_headers, }, ) except Exception as e: pass - return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers) + return FileResponse(f"{STATIC_DIR}/favicon.png") else: - return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers) + return FileResponse(f"{STATIC_DIR}/favicon.png") ############################ @@ -340,7 +360,9 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user), db: @router.post("/model/toggle", response_model=Optional[ModelResponse]) -async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def toggle_model_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): model = Models.get_model_by_id(id, db=db) if model: if ( @@ -397,7 +419,9 @@ async def update_model_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db) + model = Models.update_model_by_id( + form_data.id, ModelForm(**form_data.model_dump()), db=db + ) return model @@ -407,7 +431,11 @@ async def update_model_by_id( @router.post("/model/delete", response_model=bool) -async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def delete_model_by_id( + form_data: ModelIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): model = Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( @@ -430,6 +458,8 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u @router.delete("/delete/all", response_model=bool) -async def delete_all_models(user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def delete_all_models( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): result = Models.delete_all_models(db=db) return result