From b440a11d32b72b3be27754d0aefc136da693228f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 29 May 2024 20:40:50 -0700 Subject: [PATCH] refac --- main.py | 260 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 130 insertions(+), 130 deletions(-) diff --git a/main.py b/main.py index e8c86c9..67e7c2a 100644 --- a/main.py +++ b/main.py @@ -245,6 +245,136 @@ async def get_models(): } +@app.get("/v1") +@app.get("/") +async def get_status(): + return {"status": True} + + +@app.get("/v1/pipelines") +@app.get("/pipelines") +async def list_pipelines(user: str = Depends(get_current_user)): + if user == API_KEY: + return { + "data": [ + {"id": pipeline_id, "name": PIPELINE_NAMES[pipeline_id]} + for pipeline_id in list(PIPELINE_MODULES.keys()) + ] + } + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + +class AddPipelineForm(BaseModel): + url: str + + +async def download_file(url: str, dest_folder: str): + filename = os.path.basename(urlparse(url).path) + if not filename.endswith(".py"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="URL must point to a Python file", + ) + + file_path = os.path.join(dest_folder, filename) + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to download file", + ) + with open(file_path, "wb") as f: + f.write(await response.read()) + + return file_path + + +@app.post("/v1/pipelines/add") +@app.post("/pipelines/add") +async def add_pipeline( + form_data: AddPipelineForm, user: str = Depends(get_current_user) +): + if user != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + try: + url = convert_to_raw_url(form_data.url) + + print(url) + file_path = await download_file(url, dest_folder=PIPELINES_DIR) + await reload() + return { + "status": True, + "detail": f"Pipeline added successfully from {file_path}", + } + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +class DeletePipelineForm(BaseModel): + id: str + + +@app.delete("/v1/pipelines/delete") +@app.delete("/pipelines/delete") +async def delete_pipeline( + form_data: DeletePipelineForm, user: str = Depends(get_current_user) +): + if user != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + pipeline_id = form_data.id + pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None) + + if PIPELINE_MODULES[pipeline_id]: + if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"): + await PIPELINE_MODULES[pipeline_id].on_shutdown() + + pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py") + if os.path.exists(pipeline_path): + os.remove(pipeline_path) + await reload() + return { + "status": True, + "detail": f"Pipeline {pipeline_id} deleted successfully", + } + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Pipeline {pipeline_id} not found", + ) + + +@app.post("/v1/pipelines/reload") +@app.post("/pipelines/reload") +async def reload_pipelines(user: str = Depends(get_current_user)): + if user == API_KEY: + await reload() + return {"message": "Pipelines reloaded successfully."} + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + @app.get("/v1/{pipeline_id}/valves") @app.get("/{pipeline_id}/valves") async def get_valves(pipeline_id: str): @@ -486,133 +616,3 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm): } return await run_in_threadpool(job) - - -@app.get("/v1") -@app.get("/") -async def get_status(): - return {"status": True} - - -@app.get("/v1/pipelines") -@app.get("/pipelines") -async def list_pipelines(user: str = Depends(get_current_user)): - if user == API_KEY: - return { - "data": [ - {"id": pipeline_id, "name": PIPELINE_NAMES[pipeline_id]} - for pipeline_id in list(PIPELINE_MODULES.keys()) - ] - } - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - ) - - -class AddPipelineForm(BaseModel): - url: str - - -async def download_file(url: str, dest_folder: str): - filename = os.path.basename(urlparse(url).path) - if not filename.endswith(".py"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="URL must point to a Python file", - ) - - file_path = os.path.join(dest_folder, filename) - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status != 200: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Failed to download file", - ) - with open(file_path, "wb") as f: - f.write(await response.read()) - - return file_path - - -@app.post("/v1/pipelines/add") -@app.post("/pipelines/add") -async def add_pipeline( - form_data: AddPipelineForm, user: str = Depends(get_current_user) -): - if user != API_KEY: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - ) - - try: - url = convert_to_raw_url(form_data.url) - - print(url) - file_path = await download_file(url, dest_folder=PIPELINES_DIR) - await reload() - return { - "status": True, - "detail": f"Pipeline added successfully from {file_path}", - } - except HTTPException as e: - raise e - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e), - ) - - -class DeletePipelineForm(BaseModel): - id: str - - -@app.delete("/v1/pipelines/delete") -@app.delete("/pipelines/delete") -async def delete_pipeline( - form_data: DeletePipelineForm, user: str = Depends(get_current_user) -): - if user != API_KEY: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - ) - - pipeline_id = form_data.id - pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None) - - if PIPELINE_MODULES[pipeline_id]: - if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"): - await PIPELINE_MODULES[pipeline_id].on_shutdown() - - pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py") - if os.path.exists(pipeline_path): - os.remove(pipeline_path) - await reload() - return { - "status": True, - "detail": f"Pipeline {pipeline_id} deleted successfully", - } - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Pipeline {pipeline_id} not found", - ) - - -@app.post("/v1/pipelines/reload") -@app.post("/pipelines/reload") -async def reload_pipelines(user: str = Depends(get_current_user)): - if user == API_KEY: - await reload() - return {"message": "Pipelines reloaded successfully."} - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - )