feat: s3 support

This commit is contained in:
Timothy J. Baek 2024-10-20 23:38:26 -07:00
parent cb86e09005
commit 7984980619
2 changed files with 183 additions and 135 deletions

View File

@ -1,14 +1,19 @@
import logging import logging
import os import os
import shutil
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
import mimetypes import mimetypes
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import FileForm, FileModel, Files from open_webui.apps.webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR from open_webui.config import UPLOAD_DIR
@ -44,18 +49,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
name = filename name = filename
filename = f"{id}_{filename}" filename = f"{id}_{filename}"
file_path = f"{UPLOAD_DIR}/{filename}" contents, file_path = Storage.upload_file(file.file, filename)
contents = file.file.read()
if len(contents) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.EMPTY_CONTENT,
)
with open(file_path, "wb") as f:
f.write(contents)
f.close()
file = Files.insert_new_file( file = Files.insert_new_file(
user.id, user.id,
@ -101,7 +95,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
############################ ############################
@router.get("/", response_model=list[FileModel]) @router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)): async def list_files(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
files = Files.get_files() files = Files.get_files()
@ -118,27 +112,16 @@ async def list_files(user=Depends(get_verified_user)):
@router.delete("/all") @router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)): async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files() result = Files.delete_all_files()
if result: if result:
folder = f"{UPLOAD_DIR}"
try: try:
# Check if the directory exists Storage.delete_all_files()
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
except Exception as e: except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}") log.exception(e)
log.error(f"Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "All files deleted successfully"} return {"message": "All files deleted successfully"}
else: else:
raise HTTPException( raise HTTPException(
@ -222,21 +205,29 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content") @router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id) file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
file_path = Path(file.path) try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
# Check if the file already exists in the cache # Check if the file already exists in the cache
if file_path.is_file(): if file_path.is_file():
print(f"file_path: {file_path}") print(f"file_path: {file_path}")
headers = { headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
} }
return FileResponse(file_path, headers=headers) return FileResponse(file_path, headers=headers)
else: else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
) )
else: else:
raise HTTPException( raise HTTPException(
@ -252,6 +243,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.path file_path = file.path
if file_path: if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path) file_path = Path(file_path)
# Check if the file already exists in the cache # Check if the file already exists in the cache
@ -298,6 +290,15 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"): if file and (file.user_id == user.id or user.role == "admin"):
result = Files.delete_file_by_id(id) result = Files.delete_file_by_id(id)
if result: if result:
try:
Storage.delete_file(file.filename)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
return {"message": "File deleted successfully"} return {"message": "File deleted successfully"}
else: else:
raise HTTPException( raise HTTPException(

View File

@ -1,6 +1,12 @@
import os import os
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import shutil
from typing import BinaryIO, Tuple, Optional, Union
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ( from open_webui.config import (
STORAGE_PROVIDER, STORAGE_PROVIDER,
S3_ACCESS_KEY_ID, S3_ACCESS_KEY_ID,
@ -9,109 +15,150 @@ from open_webui.config import (
S3_REGION_NAME, S3_REGION_NAME,
S3_ENDPOINT_URL, S3_ENDPOINT_URL,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig,
) )
import boto3
from boto3.s3 import S3Client
from botocore.exceptions import ClientError
from typing import BinaryIO, Tuple, Optional
class StorageProvider: class StorageProvider:
def __init__(self): def __init__(self, provider: Optional[str] = None):
self.storage_provider = None self.storage_provider: str = provider or STORAGE_PROVIDER
self.client = None
self.bucket_name = None
if STORAGE_PROVIDER == "s3": self.s3_client = None
self.storage_provider = "s3" self.s3_bucket_name: Optional[str] = None
self.client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
else:
self.storage_provider = "local"
def get_storage_provider(self):
return self.storage_provider
def upload_file(self, file, filename):
if self.storage_provider == "s3": if self.storage_provider == "s3":
try: self._initialize_s3()
bucket = self.bucket_name
self.client.upload_fileobj(file, bucket, filename)
return filename
except ClientError as e:
raise RuntimeError(f"Error uploading file: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
f.write(file.read())
return filename
def list_files(self): def _initialize_s3(self) -> None:
if self.storage_provider == "s3": """Initializes the S3 client and bucket name if using S3 storage."""
try: self.s3_client = boto3.client(
bucket = self.bucket_name "s3",
response = self.client.list_objects_v2(Bucket=bucket) region_name=S3_REGION_NAME,
if "Contents" in response: endpoint_url=S3_ENDPOINT_URL,
return [content["Key"] for content in response["Contents"]] aws_access_key_id=S3_ACCESS_KEY_ID,
return [] aws_secret_access_key=S3_SECRET_ACCESS_KEY,
except ClientError as e: )
raise RuntimeError(f"Error listing files: {e}") self.bucket_name = S3_BUCKET_NAME
else:
return [
f
for f in os.listdir(UPLOAD_DIR)
if os.path.isfile(os.path.join(UPLOAD_DIR, f))
]
def get_file(self, filename): def _upload_to_s3(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
if self.storage_provider == "s3": """Handles uploading of the file to S3 storage."""
try: if not self.s3_client:
bucket = self.bucket_name raise RuntimeError("S3 Client is not initialized.")
file_path = f"/tmp/{filename}"
self.client.download_file(bucket, filename, file_path)
return file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename)
if os.path.isfile(file_path):
return file_path
else:
raise FileNotFoundError(f"File {filename} not found in local storage.")
def delete_file(self, filename): try:
if self.storage_provider == "s3": self.s3_client.upload_fileobj(file, self.bucket_name, filename)
try: return file.read(), f"s3://{self.bucket_name}/{filename}"
bucket = self.bucket_name except ClientError as e:
self.client.delete_object(Bucket=bucket, Key=filename) raise RuntimeError(f"Error uploading file to S3: {e}")
except ClientError as e:
raise RuntimeError(f"Error deleting file: {e}")
else:
file_path = os.path.join(UPLOAD_DIR, filename)
if os.path.isfile(file_path):
os.remove(file_path)
else:
raise FileNotFoundError(f"File {filename} not found in local storage.")
def delete_all_files(self): def _upload_to_local(self, contents: bytes, filename: str) -> Tuple[bytes, str]:
if self.storage_provider == "s3": """Handles uploading of the file to local storage."""
try: file_path = f"{UPLOAD_DIR}/{filename}"
bucket = self.bucket_name with open(file_path, "wb") as f:
response = self.client.list_objects_v2(Bucket=bucket) f.write(contents)
if "Contents" in response: return contents, file_path
for content in response["Contents"]:
self.client.delete_object(Bucket=bucket, Key=content["Key"]) def _get_file_from_s3(self, file_path: str) -> str:
except ClientError as e: """Handles downloading of the file from S3 storage."""
raise RuntimeError(f"Error deleting all files: {e}") if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def _get_file_from_local(self, file_path: str) -> str:
"""Handles downloading of the file from local storage."""
return file_path
def _delete_from_s3(self, filename: str) -> None:
"""Handles deletion of the file from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
def _delete_from_local(self, filename: str) -> None:
"""Handles deletion of the file from local storage."""
file_path = f"{UPLOAD_DIR}/{filename}"
if os.path.isfile(file_path):
os.remove(file_path)
else: else:
raise FileNotFoundError(f"File {filename} not found in local storage.")
def _delete_all_from_s3(self) -> None:
"""Handles deletion of all files from S3 storage."""
if not self.s3_client:
raise RuntimeError("S3 Client is not initialized.")
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
except ClientError as e:
raise RuntimeError(f"Error deleting all files from S3: {e}")
def _delete_all_from_local(self) -> None:
"""Handles deletion of all files from local storage."""
if os.path.exists(UPLOAD_DIR):
for filename in os.listdir(UPLOAD_DIR): for filename in os.listdir(UPLOAD_DIR):
file_path = os.path.join(UPLOAD_DIR, filename) file_path = os.path.join(UPLOAD_DIR, filename)
if os.path.isfile(file_path): try:
os.remove(file_path) if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
raise FileNotFoundError(
f"Directory {UPLOAD_DIR} not found in local storage."
)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Uploads a file either to S3 or the local file system."""
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
if self.storage_provider == "s3":
return self._upload_to_s3(file, filename)
return self._upload_to_local(contents, filename)
def get_file(self, file_path: str) -> str:
"""Downloads a file either from S3 or the local file system and returns the file path."""
if self.storage_provider == "s3":
return self._get_file_from_s3(file_path)
return self._get_file_from_local(file_path)
def delete_file(self, filename: str) -> None:
"""Deletes a file either from S3 or the local file system."""
if self.storage_provider == "s3":
self._delete_from_s3(filename)
else:
self._delete_from_local(filename)
def delete_all_files(self) -> None:
"""Deletes all files from the storage."""
if self.storage_provider == "s3":
self._delete_all_from_s3()
else:
self._delete_all_from_local()
Storage = StorageProvider() Storage = StorageProvider(provider=STORAGE_PROVIDER)