mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	added alternative storage client instantiation method, corrected filepaths, added missing type hinting
This commit is contained in:
		
							parent
							
								
									1764de41f3
								
							
						
					
					
						commit
						49f31ddcd8
					
				@ -590,6 +590,7 @@ S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
 | 
			
		||||
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
 | 
			
		||||
 | 
			
		||||
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
 | 
			
		||||
GCS_PROJECT_ID = os.environ.get("GCS_PROJECT_ID", None)
 | 
			
		||||
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON", None)
 | 
			
		||||
 | 
			
		||||
####################################
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@ import shutil
 | 
			
		||||
import json
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import BinaryIO, Tuple
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
 | 
			
		||||
import boto3
 | 
			
		||||
from botocore.exceptions import ClientError
 | 
			
		||||
@ -14,6 +13,7 @@ from open_webui.config import (
 | 
			
		||||
    S3_REGION_NAME,
 | 
			
		||||
    S3_SECRET_ACCESS_KEY,
 | 
			
		||||
    GCS_BUCKET_NAME,
 | 
			
		||||
    GCS_PROJECT_ID, 
 | 
			
		||||
    GOOGLE_APPLICATION_CREDENTIALS_JSON,
 | 
			
		||||
    STORAGE_PROVIDER,
 | 
			
		||||
    UPLOAD_DIR,
 | 
			
		||||
@ -145,41 +145,41 @@ class S3StorageProvider(StorageProvider):
 | 
			
		||||
 | 
			
		||||
class GCSStorageProvider(StorageProvider):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON))
 | 
			
		||||
        self.bucket_name = self.gcs_client.bucket(GCS_BUCKET_NAME)
 | 
			
		||||
        if GCS_PROJECT_ID:
 | 
			
		||||
            self.gcs_client = storage.Client(project=GCS_PROJECT_ID)
 | 
			
		||||
        if GOOGLE_APPLICATION_CREDENTIALS_JSON:
 | 
			
		||||
            self.gcs_client = storage.Client.from_service_account_info(info=json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON))
 | 
			
		||||
 | 
			
		||||
        self.bucket_name = GCS_BUCKET_NAME
 | 
			
		||||
        self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
 | 
			
		||||
    
 | 
			
		||||
    def upload_file(self, file: BinaryIO, filename: str):
 | 
			
		||||
    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
 | 
			
		||||
        """Handles uploading of the file to GCS storage."""
 | 
			
		||||
        contents, _ = LocalStorageProvider.upload_file(file, filename)
 | 
			
		||||
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
 | 
			
		||||
        try:
 | 
			
		||||
            # Get the blob (object in the bucket)
 | 
			
		||||
            blob = self.bucket_name.blob(filename)
 | 
			
		||||
            # Upload the file to the bucket
 | 
			
		||||
            blob.upload_from_file(BytesIO(contents))
 | 
			
		||||
            return contents, _
 | 
			
		||||
            blob = self.bucket.blob(filename)
 | 
			
		||||
            blob.upload_from_filename(file_path)
 | 
			
		||||
            return contents, "gs://" + self.bucket_name + "/" + filename
 | 
			
		||||
        except GoogleCloudError as e:
 | 
			
		||||
            raise RuntimeError(f"Error uploading file to GCS: {e}")
 | 
			
		||||
 | 
			
		||||
    def get_file(self, file_path:str) -> str:
 | 
			
		||||
        """Handles downloading of the file from GCS storage."""
 | 
			
		||||
        try:
 | 
			
		||||
            local_file_path=file_path.removeprefix(UPLOAD_DIR + "/")
 | 
			
		||||
            # Get the blob (object in the bucket)
 | 
			
		||||
            blob = self.bucket_name.blob(local_file_path)
 | 
			
		||||
            # Download the file to a local destination
 | 
			
		||||
            blob.download_to_filename(file_path)
 | 
			
		||||
            return file_path
 | 
			
		||||
            filename = file_path.removeprefix("gs://").split("/")[1]
 | 
			
		||||
            local_file_path = f"{UPLOAD_DIR}/{filename}"            
 | 
			
		||||
            blob = self.bucket.blob(filename)
 | 
			
		||||
            blob.download_to_filename(local_file_path)
 | 
			
		||||
 | 
			
		||||
            return local_file_path
 | 
			
		||||
        except NotFound as e:
 | 
			
		||||
            raise RuntimeError(f"Error downloading file from GCS: {e}")
 | 
			
		||||
    
 | 
			
		||||
    def delete_file(self, file_path:str) -> None:
 | 
			
		||||
        """Handles deletion of the file from GCS storage."""
 | 
			
		||||
        try:
 | 
			
		||||
            local_file_path = file_path.removeprefix(UPLOAD_DIR + "/")
 | 
			
		||||
            # Get the blob (object in the bucket)
 | 
			
		||||
            blob = self.bucket_name.blob(local_file_path)
 | 
			
		||||
 | 
			
		||||
            # Delete the file
 | 
			
		||||
            filename = file_path.removeprefix("gs://").split("/")[1]
 | 
			
		||||
            blob = self.bucket.blob(filename)
 | 
			
		||||
            blob.delete()
 | 
			
		||||
        except NotFound as e:
 | 
			
		||||
            raise RuntimeError(f"Error deleting file from GCS: {e}")
 | 
			
		||||
@ -190,10 +190,8 @@ class GCSStorageProvider(StorageProvider):
 | 
			
		||||
    def delete_all_files(self) -> None:
 | 
			
		||||
        """Handles deletion of all files from GCS storage."""
 | 
			
		||||
        try:
 | 
			
		||||
            # List all objects in the bucket
 | 
			
		||||
            blobs = self.bucket_name.list_blobs()
 | 
			
		||||
            blobs = self.bucket.list_blobs()
 | 
			
		||||
 | 
			
		||||
            # Delete all files
 | 
			
		||||
            for blob in blobs:
 | 
			
		||||
                blob.delete()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user