feat: s3 support

This commit is contained in:
Timothy J. Baek 2024-10-20 23:38:26 -07:00
parent cb86e09005
commit 7984980619
2 changed files with 183 additions and 135 deletions

View File

@ -1,14 +1,19 @@
import logging import logging
import os import os
import shutil
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
import mimetypes import mimetypes
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import FileForm, FileModel, Files from open_webui.apps.webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR from open_webui.config import UPLOAD_DIR
@ -44,18 +49,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
name = filename name = filename
filename = f"{id}_{filename}" filename = f"{id}_{filename}"
file_path = f"{UPLOAD_DIR}/{filename}" contents, file_path = Storage.upload_file(file.file, filename)
contents = file.file.read()
if len(contents) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.EMPTY_CONTENT,
)
with open(file_path, "wb") as f:
f.write(contents)
f.close()
file = Files.insert_new_file( file = Files.insert_new_file(
user.id, user.id,
@ -101,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
############################ ############################
@router.get("/", response_model=list[FileModel]) @router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)): async def list_files(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
files = Files.get_files() files = Files.get_files()
@ -118,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)):
@router.delete("/all") @router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)): async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files() result = Files.delete_all_files()
if result: if result:
folder = f"{UPLOAD_DIR}"
try: try:
# Check if the directory exists Storage.delete_all_files()
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e: except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}") log.exception(e)
else: log.error(f"Error deleting files")
print(f"The directory {folder} does not exist") raise HTTPException(
except Exception as e: status_code=status.HTTP_400_BAD_REQUEST,
print(f"Failed to process the directory {folder}. Reason: {e}") detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "All files deleted successfully"} return {"message": "All files deleted successfully"}
else: else:
raise HTTPException( raise HTTPException(
@ -222,9 +205,10 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content") @router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
file_path = Path(file.path) try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
# Check if the file already exists in the cache # Check if the file already exists in the cache
if file_path.is_file(): if file_path.is_file():
@ -238,6 +222,13 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -252,6 +243,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.path file_path = file.path
if file_path: if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path) file_path = Path(file_path)
# Check if the file already exists in the cache # Check if the file already exists in the cache
@ -298,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
result = Files.delete_file_by_id(id) result = Files.delete_file_by_id(id)
if result: if result:
try:
Storage.delete_file(file.filename)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "File deleted successfully"} return {"message": "File deleted successfully"}
else: else:
raise HTTPException( raise HTTPException(

View File

@ -1,6 +1,12 @@
import os import os
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import shutil
from typing import BinaryIO, Tuple, Optional, Union
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ( from open_webui.config import (
STORAGE_PROVIDER, STORAGE_PROVIDER,
S3_ACCESS_KEY_ID, S3_ACCESS_KEY_ID,
@ -9,19 +15,28 @@ from open_webui.config import (
S3_REGION_NAME, S3_REGION_NAME,
S3_ENDPOINT_URL, S3_ENDPOINT_URL,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig,
) )
class StorageProvider: import boto3
def __init__(self): from boto3.s3 import S3Client
self.storage_provider = None from botocore.exceptions import ClientError
self.client = None from typing import BinaryIO, Tuple, Optional
self.bucket_name = None
if STORAGE_PROVIDER == "s3":
self.storage_provider = "s3" class StorageProvider:
self.client = boto3.client( def __init__(self, provider: Optional[str] = None):
self.storage_provider: str = provider or STORAGE_PROVIDER
self.s3_client = None
self.s3_bucket_name: Optional[str] = None
if self.storage_provider == "s3":
self._initialize_s3()
def _initialize_s3(self) -> None:
"""Initializes the S3 client and bucket name if using S3 storage."""
self.s3_client = boto3.client(
"s3", "s3",
region_name=S3_REGION_NAME, region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL, endpoint_url=S3_ENDPOINT_URL,
@ -29,89 +44,121 @@ class StorageProvider:
aws_secret_access_key=S3_SECRET_ACCESS_KEY, aws_secret_access_key=S3_SECRET_ACCESS_KEY,
) )
self.bucket_name = S3_BUCKET_NAME self.bucket_name = S3_BUCKET_NAME
else:
self.storage_provider = "local"
def get_storage_provider(self): def _upload_to_s3(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
return self.storage_provider """Handles uploading of the file to S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
def upload_file(self, file, filename):
if self.storage_provider == "s3":
try: try:
bucket = self.bucket_name self.s3_client.upload_fileobj(file, self.bucket_name, filename)
self.client.upload_fileobj(file, bucket, filename) return file.read(), f"s3://{self.bucket_name}/{filename}"
return filename
except ClientError as e: except ClientError as e:
raise RuntimeError(f"Error uploading file: {e}") raise RuntimeError(f"Error uploading file to S3: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename) def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
os.makedirs(os.path.dirname(file_path), exist_ok=True) """Handles uploading of the file to local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(file.read()) f.write(contents)
return filename return contents, file_path
def _get_file_from_s3(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
def list_files(self):
if self.storage_provider == "s3":
try: try:
bucket = self.bucket_name bucket_name, key = file_path.split("//")[1].split("/")
response = self.client.list_objects_v2(Bucket=bucket) local_file_path = f"{UPLOAD_DIR}/{key}"
if "Contents" in response: self.s3_client.download_file(bucket_name, key, local_file_path)
return [content["Key"] for content in response["Contents"]] return local_file_path
return []
except ClientError as e: except ClientError as e:
raise RuntimeError(f"Error listing files: {e}") raise RuntimeError(f"Error downloading file from S3: {e}")
else:
return [
f
for f in os.listdir(UPLOAD_DIR)
if os.path.isfile(os.path.join(UPLOAD_DIR, f))
]
def get_file(self, filename): def _get_file_from_local(self, file_path: str) -> str:
if self.storage_provider == "s3": """Handles downloading of the file from local storage."""
try:
bucket = self.bucket_name
file_path = f"/tmp/{filename}"
self.client.download_file(bucket, filename, file_path)
return file_path return file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename)
if os.path.isfile(file_path):
return file_path
else:
raise FileNotFoundError(f"File {filename} not found in local storage.")
def delete_file(self, filename): def _delete_from_s3(self, filename: str) -> None:
if self.storage_provider == "s3": """Handles deletion of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try: try:
bucket = self.bucket_name self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
self.client.delete_object(Bucket=bucket, Key=filename)
except ClientError as e: except ClientError as e:
raise RuntimeError(f"Error deleting file: {e}") raise RuntimeError(f"Error deleting file from S3: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename) def _delete_from_local(self, filename: str) -> None:
"""Handles deletion of the file from local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
if os.path.isfile(file_path): if os.path.isfile(file_path):
os.remove(file_path) os.remove(file_path)
else: else:
raise FileNotFoundError(f"File {filename} not found in local storage.") raise FileNotFoundError(f"File {filename} not found in local storage.")
def delete_all_files(self): def _delete_all_from_s3(self) -> None:
if self.storage_provider == "s3": """Handles deletion of all files from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try: try:
bucket = self.bucket_name response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
response = self.client.list_objects_v2(Bucket=bucket)
if "Contents" in response: if "Contents" in response:
for content in response["Contents"]: for content in response["Contents"]:
self.client.delete_object(Bucket=bucket, Key=content["Key"]) self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e: except ClientError as e:
raise RuntimeError(f"Error deleting all files: {e}") raise RuntimeError(f"Error deleting all files from S3: {e}")
else:
def _delete_all_from_local(self) -> None:
"""Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR): for filename in os.listdir(UPLOAD_DIR):
file_path = os.path.join(UPLOAD_DIR, filename) file_path = os.path.join(UPLOAD_DIR, filename)
if os.path.isfile(file_path): try:
os.remove(file_path) if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
raise FileNotFoundError(
f"Directory {UPLOAD_DIR} not found in local storage."
)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
if self.storage_provider == "s3":
return self._upload_to_s3(file, filename)
return self._upload_to_local(contents, filename)
def get_file(self, file_path: str) -> str:
"""Downloads a file either from S3 or the local file system and returns the file path."""
if self.storage_provider == "s3":
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
def delete_file(self, filename: str) -> None:
"""Deletes a file either from S3 or the local file system."""
if self.storage_provider == "s3":
self._delete_from_s3(filename)
else:
self._delete_from_local(filename)
def delete_all_files(self) -> None:
"""Deletes all files from the storage."""
if self.storage_provider == "s3":
self._delete_all_from_s3()
else:
self._delete_all_from_local()
Storage = StorageProvider() Storage = StorageProvider(provider=STORAGE_PROVIDER)