From 24c0dbec0daf95b44c896a21fce09703b6b08958 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 4 May 2024 01:31:08 -0700 Subject: [PATCH] fix: pending permission issue --- backend/apps/ollama/main.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index f1c836faa..6d7e4f815 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -31,7 +31,12 @@ from typing import Optional, List, Union from apps.web.models.users import Users from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user, get_admin_user +from utils.utils import ( + decode_token, + get_current_user, + get_verified_user, + get_admin_user, +) from config import ( @@ -164,7 +169,7 @@ async def get_all_models(): @app.get("/api/tags") @app.get("/api/tags/{url_idx}") async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_current_user) + url_idx: Optional[int] = None, user=Depends(get_verified_user) ): if url_idx == None: models = await get_all_models() @@ -563,7 +568,7 @@ async def delete_model( @app.post("/api/show") -async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): +async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): if form_data.name not in app.state.MODELS: raise HTTPException( status_code=400, @@ -612,7 +617,7 @@ class GenerateEmbeddingsForm(BaseModel): async def generate_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: model = form_data.model @@ -730,7 +735,7 @@ class GenerateCompletionForm(BaseModel): async def generate_completion( form_data: GenerateCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -833,7 +838,7 @@ class GenerateChatCompletionForm(BaseModel): async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -942,7 +947,7 @@ class OpenAIChatCompletionForm(BaseModel): async def generate_openai_chat_completion( form_data: OpenAIChatCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -1241,7 +1246,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): +async def deprecated_proxy( + path: str, request: Request, user=Depends(get_verified_user) +): url = app.state.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}"