Fix dataset zip extraction might fail when creating folders concurrently

This commit is contained in:
allegroai 2023-09-12 00:50:23 +03:00
parent d458924160
commit 96b89d76b8
2 changed files with 36 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from pathlib2 import Path
from .cache import CacheManager
from .callbacks import ProgressReport
from .helper import StorageHelper
from .util import encode_string_to_filename, safe_extract
from .util import encode_string_to_filename, safe_extract, create_zip_directories
from ..debugging.log import LoggerRoot
from ..config import deferred_config
@ -163,7 +163,9 @@ class StorageManager(object):
temp_target_folder.mkdir(parents=True, exist_ok=True)
if suffix == ".zip":
ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix())
zip_file = ZipFile(cached_file.as_posix())
create_zip_directories(zip_file, path=temp_target_folder.as_posix())
zip_file.extractall(path=temp_target_folder.as_posix())
elif suffix == ".tar.gz":
with tarfile.open(cached_file.as_posix()) as file:
safe_extract(file, temp_target_folder.as_posix())

View File

@ -1,7 +1,7 @@
import fnmatch
import hashlib
import json
import os.path
import os
import re
import sys
from typing import Optional, Union, Sequence, Dict
@ -338,6 +338,37 @@ def is_within_directory(directory, target):
return prefix == abs_directory
def create_zip_directories(zipfile, path=None):
try:
path = os.getcwd() if path is None else os.fspath(path)
for member in zipfile.namelist():
arcname = member.replace("/", os.path.sep)
if os.path.altsep:
arcname = arcname.replace(os.path.altsep, os.path.sep)
# interpret absolute pathname as relative, remove drive letter or
# UNC path, redundant separators, "." and ".." components.
arcname = os.path.splitdrive(arcname)[1]
invalid_path_parts = ("", os.path.curdir, os.path.pardir)
arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) if x not in invalid_path_parts)
if os.path.sep == "\\":
# noinspection PyBroadException
try:
# filter illegal characters on Windows
# noinspection PyProtectedMember
arcname = zipfile._sanitize_windows_name(arcname, os.path.sep)
except Exception:
pass
targetpath = os.path.normpath(os.path.join(path, arcname))
# Create all upper directories if necessary.
upperdirs = os.path.dirname(targetpath)
if upperdirs:
os.makedirs(upperdirs, exist_ok=True)
except Exception as e:
LoggerRoot.get_base_logger().warning("Failed creating zip directories: " + str(e))
def safe_extract(tar, path=".", members=None, numeric_owner=False):
"""Tarfile member sanitization (addresses CVE-2007-4559)"""
for member in tar.getmembers():