From d08a1573da10b7f614722c1cc4c7c03966873e5b Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 5 Jun 2024 13:25:07 -0700 Subject: [PATCH] feat: upload endpoint --- main.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 04626fd..7c6d7e9 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Request, Depends, status, HTTPException +from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File from fastapi.middleware.cors import CORSMiddleware from fastapi.concurrency import run_in_threadpool @@ -17,6 +17,7 @@ from concurrent.futures import ThreadPoolExecutor from schemas import FilterForm, OpenAIChatCompletionForm from urllib.parse import urlparse +import shutil import aiohttp import os import importlib.util @@ -372,6 +373,51 @@ async def add_pipeline( ) +@app.post("/v1/pipelines/upload") +@app.post("/pipelines/upload") +async def upload_pipeline( + file: UploadFile = File(...), user: str = Depends(get_current_user) +): + if user != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + file_ext = os.path.splitext(file.filename)[1] + if file_ext != ".py": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python files are allowed.", + ) + + try: + # Ensure the destination folder exists + os.makedirs(PIPELINES_DIR, exist_ok=True) + + # Define the file path + file_path = os.path.join(PIPELINES_DIR, file.filename) + + # Save the uploaded file to the specified directory + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # Perform any necessary reload or processing + await reload() + + return { + "status": True, + "detail": f"Pipeline uploaded successfully to {file_path}", + } + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + class DeletePipelineForm(BaseModel): id: str