From a28ad06bf00153979a790d26595cb8ff8f3d18e2 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 16 Jun 2024 23:36:21 -0700
Subject: [PATCH] fix

---
 backend/apps/ollama/main.py | 149 ++----------------------------------
 1 file changed, 6 insertions(+), 143 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 81a3b2a0e..e82046e13 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -199,9 +199,6 @@ def merge_models_lists(model_lists):
     return list(merged_models.values())
 
 
-# user=Depends(get_current_user)
-
-
 async def get_all_models():
     log.info("get_all_models()")
 
@@ -1094,17 +1091,13 @@ async def download_file_stream(
     raise "Ollama: Could not create blob, Please try again."
 
 
-# def number_generator():
-#     for i in range(1, 101):
-#         yield f"data: {i}\n"
-
-
 # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
 @app.post("/models/download")
 @app.post("/models/download/{url_idx}")
 async def download_model(
     form_data: UrlForm,
     url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
 ):
     allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
 
@@ -1133,7 +1126,11 @@ async def download_model(
 
 @app.post("/models/upload")
 @app.post("/models/upload/{url_idx}")
-def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
+def upload_model(
+    file: UploadFile = File(...),
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
     if url_idx == None:
         url_idx = 0
     ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -1196,137 +1193,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
             yield f"data: {json.dumps(res)}\n\n"
 
     return StreamingResponse(file_process_stream(), media_type="text/event-stream")
-
-
-# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
-#     if url_idx == None:
-#         url_idx = 0
-#     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-#     file_location = os.path.join(UPLOAD_DIR, file.filename)
-#     total_size = file.size
-
-#     async def file_upload_generator(file):
-#         print(file)
-#         try:
-#             async with aiofiles.open(file_location, "wb") as f:
-#                 completed_size = 0
-#                 while True:
-#                     chunk = await file.read(1024*1024)
-#                     if not chunk:
-#                         break
-#                     await f.write(chunk)
-#                     completed_size += len(chunk)
-#                     progress = (completed_size / total_size) * 100
-
-#                     print(progress)
-#                     yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
-#         except Exception as e:
-#             print(e)
-#             yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
-#         finally:
-#             await file.close()
-#             print("done")
-#             yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
-
-#     return StreamingResponse(
-#         file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
-#     )
-
-
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def deprecated_proxy(
-    path: str, request: Request, user=Depends(get_verified_user)
-):
-    url = app.state.config.OLLAMA_BASE_URLS[0]
-    target_url = f"{url}/{path}"
-
-    body = await request.body()
-    headers = dict(request.headers)
-
-    if user.role in ["user", "admin"]:
-        if path in ["pull", "delete", "push", "copy", "create"]:
-            if user.role != "admin":
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
-    headers.pop("host", None)
-    headers.pop("authorization", None)
-    headers.pop("origin", None)
-    headers.pop("referer", None)
-
-    r = None
-
-    def get_request():
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if path == "generate":
-                        data = json.loads(body.decode("utf-8"))
-
-                        if data.get("stream", True):
-                            yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    elif path == "chat":
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                    if request_id in REQUEST_POOL:
-                        REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method=request.method,
-                url=target_url,
-                data=body,
-                headers=headers,
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            # r.close()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
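
Net effect of the patch: the admin check that deprecated_proxy performed inline (rejecting non-admin users for pull/delete/push/copy/create) now happens at the route level, where user=Depends(get_admin_user) on /models/download and /models/upload rejects non-admin callers before either handler body runs, and the catch-all proxy route is removed entirely. The sketch below illustrates that FastAPI dependency pattern in isolation; it is a minimal stand-in rather than Open WebUI's actual code, and the User model, the get_current_user body, and the error string are assumptions made up for the example.

from fastapi import Depends, FastAPI, HTTPException, status
from pydantic import BaseModel

app = FastAPI()


class User(BaseModel):
    # Assumed shape; Open WebUI's real user model carries more fields.
    name: str
    role: str


def get_current_user() -> User:
    # Stand-in: the real dependency decodes the request's auth token.
    return User(name="alice", role="user")


def get_admin_user(user: User = Depends(get_current_user)) -> User:
    # Reject non-admins before any route body that depends on this runs.
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access prohibited",  # stand-in for ERROR_MESSAGES.ACCESS_PROHIBITED
        )
    return user


@app.post("/models/download")
async def download_model(user: User = Depends(get_admin_user)):
    # Reached only when get_admin_user resolved to an admin.
    return {"status": "started", "requested_by": user.name}

Because the dependency raises before the endpoint executes, FastAPI returns the 401 without the handler doing any work, which is the same guarantee the deleted proxy's in-handler role checks provided.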