diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 190d2d1c3..bdc6ec4f4 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -12,6 +12,7 @@ from apps.webui.routers import ( configs, memories, utils, + files, ) from config import ( WEBUI_BUILD_HASH, @@ -81,6 +82,7 @@ app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) +app.include_router(files.router, prefix="/files", tags=["files"]) @app.get("/") diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py new file mode 100644 index 000000000..c34bd46d8 --- /dev/null +++ b/backend/apps/webui/models/files.py @@ -0,0 +1,103 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time +import logging +from apps.webui.internal.db import DB, JSONField + +import json + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# Files DB Schema +#################### + + +class File(Model): + id = CharField(unique=True) + user_id = CharField() + filename = TextField() + meta = JSONField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class FileModel(BaseModel): + id: str + user_id: str + filename: str + meta: dict + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class FileResponse(BaseModel): + id: str + user_id: str + filename: str + meta: dict + created_at: int # timestamp in epoch + + +class FileForm(BaseModel): + id: str + filename: str + meta: dict = {} + + +class FilesTable: + def __init__(self, db): + self.db = db + self.db.create_tables([File]) + + def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: + file = FileModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + } + ) + + try: + result = File.create(**file.model_dump()) + if result: + return file + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") + return None + + def get_file_by_id(self, id: str) -> Optional[FileModel]: + try: + file = File.get(File.id == id) + return FileModel(**model_to_dict(file)) + except: + return None + + def get_files(self) -> List[FileModel]: + return [FileModel(**model_to_dict(file)) for file in File.select()] + + def delete_file_by_id(self, id: str) -> bool: + try: + query = File.delete().where((File.id == id)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Files = FilesTable(DB) diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py new file mode 100644 index 000000000..773386059 --- /dev/null +++ b/backend/apps/webui/routers/files.py @@ -0,0 +1,134 @@ +from fastapi import ( + Depends, + FastAPI, + HTTPException, + status, + Request, + UploadFile, + File, + Form, +) + + +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.files import Files, FileForm, FileModel, FileResponse +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +from importlib import util +import os +import uuid + +from config import SRC_LOG_LEVELS, UPLOAD_DIR + + +import logging + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +router = APIRouter() + +############################ +# Upload File +############################ + + +@router.post("/") +def upload_file( + file: UploadFile = File(...), + user=Depends(get_verified_user), +): + log.info(f"file.content_type: {file.content_type}") + try: + unsanitized_filename = file.filename + filename = os.path.basename(unsanitized_filename) + + # replace filename with uuid + id = str(uuid.uuid4()) + file_path = f"{UPLOAD_DIR}/{filename}" + + contents = file.file.read() + with open(file_path, "wb") as f: + f.write(contents) + f.close() + + file = Files.insert_new_file( + user.id, FileForm(**{"id": id, "filename": filename}) + ) + + if file: + return file + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), + ) + + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# List Files +############################ + + +@router.get("/", response_model=List[FileModel]) +async def list_files(user=Depends(get_verified_user)): + files = Files.get_files() + return files + + +############################ +# Get File By Id +############################ + + +@router.get("/{id}", response_model=Optional[FileModel]) +async def get_file_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) + + if file: + return file + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Delete File By Id +############################ + + +@router.delete("/{id}") +async def delete_file_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) + + if file: + result = Files.delete_file_by_id(id) + if result: + return {"message": "File deleted successfully"} + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting file"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) diff --git a/backend/main.py b/backend/main.py index db1fa1640..7f82dd7c7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,6 +11,7 @@ import requests import mimetypes import shutil import os +import uuid import inspect import asyncio @@ -76,6 +77,7 @@ from config import ( VERSION, CHANGELOG, FRONTEND_BUILD_DIR, + UPLOAD_DIR, CACHE_DIR, STATIC_DIR, ENABLE_OPENAI_API, @@ -1378,6 +1380,7 @@ async def update_pipeline_valves( ) + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA