Fix potential path traversal on file download (CVE-2024-24591)

This commit is contained in:
allegroai 2024-01-30 19:28:31 +02:00
parent 1b77e08ec8
commit 831c1394da
2 changed files with 33 additions and 2 deletions

View File

@ -29,6 +29,7 @@ from ..storage.util import sha256sum, is_windows, md5text, format_size
from ..utilities.matching import matches_any_wildcard
from ..utilities.parallel import ParallelZipper
from ..utilities.version import Version
from ..utilities.files import is_path_traversal
try:
from pathlib import Path as _Path # noqa
@ -1856,6 +1857,12 @@ class Dataset(object):
for ds in datasets:
base_folder = Path(ds._get_dataset_files())
files = [f.relative_path for f in ds.file_entries if f.parent_dataset_id == ds.id]
files = [
os.path.basename(file)
if is_path_traversal(base_folder, file) or is_path_traversal(temp_folder, file)
else file
for file in files
]
pool.map(
lambda x:
(temp_folder / x).parent.mkdir(parents=True, exist_ok=True) or
@ -2326,12 +2333,24 @@ class Dataset(object):
link.size = Path(target_path).stat().st_size
if not max_workers:
for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path)
if not is_path_traversal(target_folder, relative_path):
target_path = os.path.join(target_folder, relative_path)
else:
LoggerRoot.get_base_logger().warning(
"Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
)
target_path = os.path.join(target_folder, os.path.basename(relative_path))
_download_link(link, target_path)
else:
with ThreadPoolExecutor(max_workers=max_workers) as pool:
for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path)
if not is_path_traversal(target_folder, relative_path):
target_path = os.path.join(target_folder, relative_path)
else:
LoggerRoot.get_base_logger().warning(
"Ignoring relative path `{}`: it must not traverse directories".format(relative_path)
)
target_path = os.path.join(target_folder, os.path.basename(relative_path))
pool.submit(_download_link, link, target_path)
def _extract_dataset_archive(

View File

@ -20,3 +20,15 @@ def get_filename_max_length(dir_path):
print(err)
return 255 # Common filesystems like NTFS, EXT4 and HFS+ limited with 255
def is_path_traversal(target_folder, relative_path):
try:
target_folder = pathlib2.Path(target_folder)
relative_path = pathlib2.Path(relative_path)
# returns the relative path starting from the target_folder,
# or raise an ValueError if a directory traversal attack is tried
target_folder.joinpath(relative_path).resolve().relative_to(target_folder.resolve())
return False
except ValueError:
return True