diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5aae07e3c..2f0d99261 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1951,6 +1951,12 @@ RAG_FILE_MAX_SIZE = PersistentConfig( ), ) +RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig( + "RAG_ALLOWED_FILE_EXTENSIONS", + "rag.file.allowed_extensions", + os.environ.get("RAG_ALLOWED_FILE_EXTENSIONS", "").split(","), +) + RAG_EMBEDDING_ENGINE = PersistentConfig( "RAG_EMBEDDING_ENGINE", "rag.embedding_engine", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 3d1036785..46f72a111 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -197,6 +197,7 @@ from open_webui.config import ( RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_BATCH_SIZE, RAG_RELEVANCE_THRESHOLD, + RAG_ALLOWED_FILE_EXTENSIONS, RAG_FILE_MAX_COUNT, RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, @@ -638,6 +639,7 @@ app.state.FUNCTIONS = {} app.state.config.TOP_K = RAG_TOP_K app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index ba589070f..ec76d7e63 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -95,6 +95,16 @@ def upload_file( unsanitized_filename = file.filename filename = os.path.basename(unsanitized_filename) + file_extension = os.path.splitext(filename)[1] + if request.app.state.config.ALLOWED_FILE_EXTENSIONS: + if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT( + f"File type {file_extension} is not allowed" + ), + ) + # replace filename with uuid id = str(uuid.uuid4()) name = filename