diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index e7433f649..671429bb5 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -448,25 +448,11 @@ def store_doc( log.info(f"file.content_type: {file.content_type}") try: - is_valid_filename = True unsanitized_filename = file.filename - if re.search(r'[\\/:"\*\?<>|\n\t ]', unsanitized_filename) is not None: - is_valid_filename = False + filename = os.path.basename(unsanitized_filename) - unvalidated_file_path = f"{UPLOAD_DIR}/{unsanitized_filename}" - dereferenced_file_path = str(Path(unvalidated_file_path).resolve(strict=False)) - if not dereferenced_file_path.startswith(UPLOAD_DIR): - is_valid_filename = False + file_path = f"{UPLOAD_DIR}/{filename}" - if is_valid_filename: - file_path = dereferenced_file_path - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - filename = file.filename contents = file.file.read() with open(file_path, "wb") as f: f.write(contents) @@ -477,7 +463,7 @@ def store_doc( collection_name = calculate_sha256(f)[:63] f.close() - loader, known_type = get_loader(file.filename, file.content_type, file_path) + loader, known_type = get_loader(filename, file.content_type, file_path) data = loader.load() try: