Fix dataset zip extraction that might fail when folders are created concurrently

This commit is contained in:
allegroai 2023-09-12 00:50:23 +03:00
parent d458924160
commit 96b89d76b8
2 changed files with 36 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from pathlib2 import Path
from .cache import CacheManager from .cache import CacheManager
from .callbacks import ProgressReport from .callbacks import ProgressReport
from .helper import StorageHelper from .helper import StorageHelper
from .util import encode_string_to_filename, safe_extract from .util import encode_string_to_filename, safe_extract, create_zip_directories
from ..debugging.log import LoggerRoot from ..debugging.log import LoggerRoot
from ..config import deferred_config from ..config import deferred_config
@ -163,7 +163,9 @@ class StorageManager(object):
temp_target_folder.mkdir(parents=True, exist_ok=True) temp_target_folder.mkdir(parents=True, exist_ok=True)
if suffix == ".zip": if suffix == ".zip":
ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix()) zip_file = ZipFile(cached_file.as_posix())
create_zip_directories(zip_file, path=temp_target_folder.as_posix())
zip_file.extractall(path=temp_target_folder.as_posix())
elif suffix == ".tar.gz": elif suffix == ".tar.gz":
with tarfile.open(cached_file.as_posix()) as file: with tarfile.open(cached_file.as_posix()) as file:
safe_extract(file, temp_target_folder.as_posix()) safe_extract(file, temp_target_folder.as_posix())

View File

@ -1,7 +1,7 @@
import fnmatch import fnmatch
import hashlib import hashlib
import json import json
import os.path import os
import re import re
import sys import sys
from typing import Optional, Union, Sequence, Dict from typing import Optional, Union, Sequence, Dict
@ -338,6 +338,37 @@ def is_within_directory(directory, target):
return prefix == abs_directory return prefix == abs_directory
def create_zip_directories(zipfile, path=None):
    """
    Pre-create the directories needed to extract ``zipfile`` into ``path``.

    Each member name is sanitized (separators normalized, drive letter removed,
    empty / "." / ".." components dropped) and the resulting directory is created
    with ``exist_ok=True``, so a directory that already exists is not an error.
    For regular members the parent directory is created; for directory entries
    (names ending with "/") the directory itself is created as well.

    This is best-effort: any failure is logged as a warning and not raised.

    :param zipfile: An open ``ZipFile``-like object exposing ``namelist()``
    :param path: Extraction root. Defaults to the current working directory
    :return: None
    """
    try:
        path = os.getcwd() if path is None else os.fspath(path)
        # path components that must never appear in an extracted path
        invalid_path_parts = ("", os.path.curdir, os.path.pardir)
        # directories already handled, avoids repeated makedirs calls for the same folder
        created = set()
        for member in zipfile.namelist():
            arcname = member.replace("/", os.path.sep)
            if os.path.altsep:
                arcname = arcname.replace(os.path.altsep, os.path.sep)
            # interpret absolute pathname as relative, remove drive letter or
            # UNC path, redundant separators, "." and ".." components.
            arcname = os.path.splitdrive(arcname)[1]
            arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) if x not in invalid_path_parts)
            if os.path.sep == "\\":
                # noinspection PyBroadException
                try:
                    # filter illegal characters on Windows
                    # noinspection PyProtectedMember
                    arcname = zipfile._sanitize_windows_name(arcname, os.path.sep)
                except Exception:
                    # private helper may be missing, fall back to the unsanitized name
                    pass

            targetpath = os.path.normpath(os.path.join(path, arcname))
            # The filtering above strips the trailing separator of directory entries,
            # so decide based on the original member name: a directory entry needs the
            # directory itself, a regular entry only needs its parent directories.
            if member.endswith("/"):
                directory = targetpath
            else:
                directory = os.path.dirname(targetpath)

            if directory and directory not in created:
                os.makedirs(directory, exist_ok=True)
                created.add(directory)
    except Exception as e:
        LoggerRoot.get_base_logger().warning("Failed creating zip directories: " + str(e))
def safe_extract(tar, path=".", members=None, numeric_owner=False): def safe_extract(tar, path=".", members=None, numeric_owner=False):
"""Tarfile member sanitization (addresses CVE-2007-4559)""" """Tarfile member sanitization (addresses CVE-2007-4559)"""
for member in tar.getmembers(): for member in tar.getmembers():