mirror of
https://github.com/clearml/clearml
synced 2025-04-14 12:31:47 +00:00
Refactor inner functions
This commit is contained in:
parent
8c776b6da0
commit
ff73e7848c
clearml/storage
@ -14,7 +14,7 @@ from pathlib2 import Path
|
|||||||
|
|
||||||
from .cache import CacheManager
|
from .cache import CacheManager
|
||||||
from .helper import StorageHelper
|
from .helper import StorageHelper
|
||||||
from .util import encode_string_to_filename
|
from .util import encode_string_to_filename, safe_extract
|
||||||
from ..debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
|
|
||||||
|
|
||||||
@ -165,47 +165,9 @@ class StorageManager(object):
|
|||||||
ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix())
|
ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix())
|
||||||
elif suffix == ".tar.gz":
|
elif suffix == ".tar.gz":
|
||||||
with tarfile.open(cached_file.as_posix()) as file:
|
with tarfile.open(cached_file.as_posix()) as file:
|
||||||
def is_within_directory(directory, target):
|
|
||||||
|
|
||||||
abs_directory = os.path.abspath(directory)
|
|
||||||
abs_target = os.path.abspath(target)
|
|
||||||
|
|
||||||
prefix = os.path.commonprefix([abs_directory, abs_target])
|
|
||||||
|
|
||||||
return prefix == abs_directory
|
|
||||||
|
|
||||||
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
|
|
||||||
|
|
||||||
for member in tar.getmembers():
|
|
||||||
member_path = os.path.join(path, member.name)
|
|
||||||
if not is_within_directory(path, member_path):
|
|
||||||
raise Exception("Attempted Path Traversal in Tar File")
|
|
||||||
|
|
||||||
tar.extractall(path, members, numeric_owner=numeric_owner)
|
|
||||||
|
|
||||||
|
|
||||||
safe_extract(file, temp_target_folder.as_posix())
|
safe_extract(file, temp_target_folder.as_posix())
|
||||||
elif suffix == ".tgz":
|
elif suffix == ".tgz":
|
||||||
with tarfile.open(cached_file.as_posix(), mode='r:gz') as file:
|
with tarfile.open(cached_file.as_posix(), mode='r:gz') as file:
|
||||||
def is_within_directory(directory, target):
|
|
||||||
|
|
||||||
abs_directory = os.path.abspath(directory)
|
|
||||||
abs_target = os.path.abspath(target)
|
|
||||||
|
|
||||||
prefix = os.path.commonprefix([abs_directory, abs_target])
|
|
||||||
|
|
||||||
return prefix == abs_directory
|
|
||||||
|
|
||||||
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
|
|
||||||
|
|
||||||
for member in tar.getmembers():
|
|
||||||
member_path = os.path.join(path, member.name)
|
|
||||||
if not is_within_directory(path, member_path):
|
|
||||||
raise Exception("Attempted Path Traversal in Tar File")
|
|
||||||
|
|
||||||
tar.extractall(path, members, numeric_owner=numeric_owner)
|
|
||||||
|
|
||||||
|
|
||||||
safe_extract(file, temp_target_folder.as_posix())
|
safe_extract(file, temp_target_folder.as_posix())
|
||||||
|
|
||||||
if temp_target_folder != target_folder:
|
if temp_target_folder != target_folder:
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
|
import fnmatch
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import os.path
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from zlib import crc32
|
|
||||||
from typing import Optional, Union, Sequence, Dict
|
from typing import Optional, Union, Sequence, Dict
|
||||||
from pathlib2 import Path
|
from zlib import crc32
|
||||||
|
|
||||||
from six.moves.urllib.parse import quote, urlparse, urlunparse
|
|
||||||
import six
|
import six
|
||||||
import fnmatch
|
from pathlib2 import Path
|
||||||
|
from six.moves.urllib.parse import quote, urlparse, urlunparse
|
||||||
|
|
||||||
from ..debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
|
|
||||||
@ -16,11 +17,13 @@ from ..debugging.log import LoggerRoot
|
|||||||
def get_config_object_matcher(**patterns):
|
def get_config_object_matcher(**patterns):
|
||||||
unsupported = {k: v for k, v in patterns.items() if not isinstance(v, six.string_types)}
|
unsupported = {k: v for k, v in patterns.items() if not isinstance(v, six.string_types)}
|
||||||
if unsupported:
|
if unsupported:
|
||||||
raise ValueError('Unsupported object matcher (expecting string): %s'
|
raise ValueError(
|
||||||
% ', '.join('%s=%s' % (k, v) for k, v in unsupported.items()))
|
"Unsupported object matcher (expecting string): %s"
|
||||||
|
% ", ".join("%s=%s" % (k, v) for k, v in unsupported.items())
|
||||||
|
)
|
||||||
|
|
||||||
# optimize simple patters
|
# optimize simple patters
|
||||||
starts_with = {k: v.rstrip('*') for k, v in patterns.items() if '*' not in v.rstrip('*') and '?' not in v}
|
starts_with = {k: v.rstrip("*") for k, v in patterns.items() if "*" not in v.rstrip("*") and "?" not in v}
|
||||||
patterns = {k: v for k, v in patterns.items() if v not in starts_with}
|
patterns = {k: v for k, v in patterns.items() if v not in starts_with}
|
||||||
|
|
||||||
def _matcher(**kwargs):
|
def _matcher(**kwargs):
|
||||||
@ -60,7 +63,7 @@ def sha256sum(filename, skip_header=0, block_size=65536):
|
|||||||
b = bytearray(block_size)
|
b = bytearray(block_size)
|
||||||
mv = memoryview(b)
|
mv = memoryview(b)
|
||||||
try:
|
try:
|
||||||
with open(filename, 'rb', buffering=0) as f:
|
with open(filename, "rb", buffering=0) as f:
|
||||||
# skip header
|
# skip header
|
||||||
if skip_header:
|
if skip_header:
|
||||||
file_hash.update(f.read(skip_header))
|
file_hash.update(f.read(skip_header))
|
||||||
@ -86,7 +89,7 @@ def md5text(text, seed=1337):
|
|||||||
:param seed: use prefix seed for hashing
|
:param seed: use prefix seed for hashing
|
||||||
:return: md5 string
|
:return: md5 string
|
||||||
"""
|
"""
|
||||||
return hash_text(text=text, seed=seed, hash_func='md5')
|
return hash_text(text=text, seed=seed, hash_func="md5")
|
||||||
|
|
||||||
|
|
||||||
def crc32text(text, seed=1337):
|
def crc32text(text, seed=1337):
|
||||||
@ -99,10 +102,10 @@ def crc32text(text, seed=1337):
|
|||||||
:param seed: use prefix seed for hashing
|
:param seed: use prefix seed for hashing
|
||||||
:return: crc32 hex in string (32bits = 8 characters in hex)
|
:return: crc32 hex in string (32bits = 8 characters in hex)
|
||||||
"""
|
"""
|
||||||
return '{:08x}'.format(crc32((str(seed)+str(text)).encode('utf-8')))
|
return "{:08x}".format(crc32((str(seed) + str(text)).encode("utf-8")))
|
||||||
|
|
||||||
|
|
||||||
def hash_text(text, seed=1337, hash_func='md5'):
|
def hash_text(text, seed=1337, hash_func="md5"):
|
||||||
# type: (str, Union[int, str], str) -> str
|
# type: (str, Union[int, str], str) -> str
|
||||||
"""
|
"""
|
||||||
Return hash_func (md5/sha1/sha256/sha384/sha512) hash of a string
|
Return hash_func (md5/sha1/sha256/sha384/sha512) hash of a string
|
||||||
@ -112,13 +115,13 @@ def hash_text(text, seed=1337, hash_func='md5'):
|
|||||||
:param hash_func: hashing function. currently supported md5 sha256
|
:param hash_func: hashing function. currently supported md5 sha256
|
||||||
:return: hashed string
|
:return: hashed string
|
||||||
"""
|
"""
|
||||||
assert hash_func in ('md5', 'sha256', 'sha256', 'sha384', 'sha512')
|
assert hash_func in ("md5", "sha256", "sha256", "sha384", "sha512")
|
||||||
h = getattr(hashlib, hash_func)()
|
h = getattr(hashlib, hash_func)()
|
||||||
h.update((str(seed) + str(text)).encode('utf-8'))
|
h.update((str(seed) + str(text)).encode("utf-8"))
|
||||||
return h.hexdigest()
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def hash_dict(a_dict, seed=1337, hash_func='md5'):
|
def hash_dict(a_dict, seed=1337, hash_func="md5"):
|
||||||
# type: (Dict, Union[int, str], str) -> str
|
# type: (Dict, Union[int, str], str) -> str
|
||||||
"""
|
"""
|
||||||
Return hash_func (crc32/md5/sha1/sha256/sha384/sha512) hash of the dict values
|
Return hash_func (crc32/md5/sha1/sha256/sha384/sha512) hash of the dict values
|
||||||
@ -129,9 +132,9 @@ def hash_dict(a_dict, seed=1337, hash_func='md5'):
|
|||||||
:param hash_func: hashing function. currently supported md5 sha256
|
:param hash_func: hashing function. currently supported md5 sha256
|
||||||
:return: hashed string
|
:return: hashed string
|
||||||
"""
|
"""
|
||||||
assert hash_func in ('crc32', 'md5', 'sha256', 'sha256', 'sha384', 'sha512')
|
assert hash_func in ("crc32", "md5", "sha256", "sha256", "sha384", "sha512")
|
||||||
repr_string = json.dumps(a_dict, sort_keys=True)
|
repr_string = json.dumps(a_dict, sort_keys=True)
|
||||||
if hash_func == 'crc32':
|
if hash_func == "crc32":
|
||||||
return crc32text(repr_string, seed=seed)
|
return crc32text(repr_string, seed=seed)
|
||||||
else:
|
else:
|
||||||
return hash_text(repr_string, seed=seed, hash_func=hash_func)
|
return hash_text(repr_string, seed=seed, hash_func=hash_func)
|
||||||
@ -141,7 +144,7 @@ def is_windows():
|
|||||||
"""
|
"""
|
||||||
:return: True if currently running on windows OS
|
:return: True if currently running on windows OS
|
||||||
"""
|
"""
|
||||||
return sys.platform == 'win32'
|
return sys.platform == "win32"
|
||||||
|
|
||||||
|
|
||||||
def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b_instead_of_bytes=False):
|
def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b_instead_of_bytes=False):
|
||||||
@ -185,7 +188,7 @@ def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b
|
|||||||
for i, m in enumerate(scale):
|
for i, m in enumerate(scale):
|
||||||
if size < k ** (i + 1) or i == len(scale) - 1:
|
if size < k ** (i + 1) or i == len(scale) - 1:
|
||||||
return (
|
return (
|
||||||
("{:.2f}".format(size / (k ** i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size)))
|
("{:.2f}".format(size / (k**i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size)))
|
||||||
+ " "
|
+ " "
|
||||||
+ m
|
+ m
|
||||||
)
|
)
|
||||||
@ -226,42 +229,53 @@ def parse_size(size, binary=False):
|
|||||||
>>> parse_size('1.5 GB', binary=True)
|
>>> parse_size('1.5 GB', binary=True)
|
||||||
1610612736
|
1610612736
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def tokenize(text):
|
def tokenize(text):
|
||||||
tokenized_input = []
|
tokenized_input = []
|
||||||
for token in re.split(r'(\d+(?:\.\d+)?)', text):
|
for token in re.split(r"(\d+(?:\.\d+)?)", text):
|
||||||
token = token.strip()
|
token = token.strip()
|
||||||
if re.match(r'\d+\.\d+', token):
|
if re.match(r"\d+\.\d+", token):
|
||||||
tokenized_input.append(float(token))
|
tokenized_input.append(float(token))
|
||||||
elif token.isdigit():
|
elif token.isdigit():
|
||||||
tokenized_input.append(int(token))
|
tokenized_input.append(int(token))
|
||||||
elif token:
|
elif token:
|
||||||
tokenized_input.append(token)
|
tokenized_input.append(token)
|
||||||
return tokenized_input
|
return tokenized_input
|
||||||
|
|
||||||
tokens = tokenize(str(size))
|
tokens = tokenize(str(size))
|
||||||
if tokens and isinstance(tokens[0], (int, float)):
|
if tokens and isinstance(tokens[0], (int, float)):
|
||||||
disk_size_units_b = \
|
disk_size_units_b = (
|
||||||
(('B', 'bytes'), ('KiB', 'kibibyte'), ('MiB', 'mebibyte'), ('GiB', 'gibibyte'),
|
("B", "bytes"),
|
||||||
('TiB', 'tebibyte'), ('PiB', 'pebibyte'))
|
("KiB", "kibibyte"),
|
||||||
disk_size_units_d = \
|
("MiB", "mebibyte"),
|
||||||
(('B', 'bytes'), ('KB', 'kilobyte'), ('MB', 'megabyte'), ('GB', 'gigabyte'),
|
("GiB", "gibibyte"),
|
||||||
('TB', 'terabyte'), ('PB', 'petabyte'))
|
("TiB", "tebibyte"),
|
||||||
disk_size_units_b = [(1024 ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)]
|
("PiB", "pebibyte"),
|
||||||
|
)
|
||||||
|
disk_size_units_d = (
|
||||||
|
("B", "bytes"),
|
||||||
|
("KB", "kilobyte"),
|
||||||
|
("MB", "megabyte"),
|
||||||
|
("GB", "gigabyte"),
|
||||||
|
("TB", "terabyte"),
|
||||||
|
("PB", "petabyte"),
|
||||||
|
)
|
||||||
|
disk_size_units_b = [(1024**i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)]
|
||||||
k = 1024 if binary else 1000
|
k = 1024 if binary else 1000
|
||||||
disk_size_units_d = [(k ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)]
|
disk_size_units_d = [(k**i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)]
|
||||||
disk_size_units = (disk_size_units_b + disk_size_units_d) \
|
disk_size_units = (disk_size_units_b + disk_size_units_d) if binary else (disk_size_units_d + disk_size_units_b)
|
||||||
if binary else (disk_size_units_d + disk_size_units_b)
|
|
||||||
|
|
||||||
# Get the normalized unit (if any) from the tokenized input.
|
# Get the normalized unit (if any) from the tokenized input.
|
||||||
normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else ''
|
normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else ""
|
||||||
# If the input contains only a number, it's assumed to be the number of
|
# If the input contains only a number, it's assumed to be the number of
|
||||||
# bytes. The second token can also explicitly reference the unit bytes.
|
# bytes. The second token can also explicitly reference the unit bytes.
|
||||||
if len(tokens) == 1 or normalized_unit.startswith('b'):
|
if len(tokens) == 1 or normalized_unit.startswith("b"):
|
||||||
return int(tokens[0])
|
return int(tokens[0])
|
||||||
# Otherwise we expect two tokens: A number and a unit.
|
# Otherwise we expect two tokens: A number and a unit.
|
||||||
if normalized_unit:
|
if normalized_unit:
|
||||||
# Convert plural units to singular units, for details:
|
# Convert plural units to singular units, for details:
|
||||||
# https://github.com/xolox/python-humanfriendly/issues/26
|
# https://github.com/xolox/python-humanfriendly/issues/26
|
||||||
normalized_unit = normalized_unit.rstrip('s')
|
normalized_unit = normalized_unit.rstrip("s")
|
||||||
for k, low, high in disk_size_units:
|
for k, low, high in disk_size_units:
|
||||||
# First we check for unambiguous symbols (KiB, MiB, GiB, etc)
|
# First we check for unambiguous symbols (KiB, MiB, GiB, etc)
|
||||||
# and names (kibibyte, mebibyte, gibibyte, etc) because their
|
# and names (kibibyte, mebibyte, gibibyte, etc) because their
|
||||||
@ -271,8 +285,7 @@ def parse_size(size, binary=False):
|
|||||||
# Now we will deal with ambiguous prefixes (K, M, G, etc),
|
# Now we will deal with ambiguous prefixes (K, M, G, etc),
|
||||||
# symbols (KB, MB, GB, etc) and names (kilobyte, megabyte,
|
# symbols (KB, MB, GB, etc) and names (kilobyte, megabyte,
|
||||||
# gigabyte, etc) according to the caller's preference.
|
# gigabyte, etc) according to the caller's preference.
|
||||||
if (normalized_unit in (low.lower(), high.lower()) or
|
if normalized_unit in (low.lower(), high.lower()) or normalized_unit.startswith(low.lower()):
|
||||||
normalized_unit.startswith(low.lower())):
|
|
||||||
return int(tokens[0] * k)
|
return int(tokens[0] * k)
|
||||||
|
|
||||||
raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens))
|
raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens))
|
||||||
@ -301,8 +314,7 @@ def get_common_path(list_of_files):
|
|||||||
if f_parts[:num_p] == common_path_parts[:num_p]:
|
if f_parts[:num_p] == common_path_parts[:num_p]:
|
||||||
common_path_parts = common_path_parts[:num_p]
|
common_path_parts = common_path_parts[:num_p]
|
||||||
continue
|
continue
|
||||||
num_p = min(
|
num_p = min([i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1])
|
||||||
[i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1])
|
|
||||||
# no common path, break
|
# no common path, break
|
||||||
if num_p < 0:
|
if num_p < 0:
|
||||||
common_path_parts = []
|
common_path_parts = []
|
||||||
@ -317,3 +329,19 @@ def get_common_path(list_of_files):
|
|||||||
return common_path.as_posix()
|
return common_path.as_posix()
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_within_directory(directory, target):
|
||||||
|
abs_directory = os.path.abspath(directory)
|
||||||
|
abs_target = os.path.abspath(target)
|
||||||
|
prefix = os.path.commonprefix([abs_directory, abs_target])
|
||||||
|
return prefix == abs_directory
|
||||||
|
|
||||||
|
|
||||||
|
def safe_extract(tar, path=".", members=None, numeric_owner=False):
|
||||||
|
"""Tarfile member sanitization (addresses CVE-2007-4559)"""
|
||||||
|
for member in tar.getmembers():
|
||||||
|
member_path = os.path.join(path, member.name)
|
||||||
|
if not is_within_directory(path, member_path):
|
||||||
|
raise Exception("Attempted Path Traversal in Tar File")
|
||||||
|
tar.extractall(path, members, numeric_owner=numeric_owner)
|
||||||
|
Loading…
Reference in New Issue
Block a user