Refactor inner functions

allegroai committed on 2022-10-30 19:29:01 +02:00
commit ff73e7848c (parent 8c776b6da0)
2 changed files with 66 additions and 76 deletions

clearml/storage/manager.py

@@ -14,7 +14,7 @@ from pathlib2 import Path
 from .cache import CacheManager
 from .helper import StorageHelper
-from .util import encode_string_to_filename
+from .util import encode_string_to_filename, safe_extract
 from ..debugging.log import LoggerRoot
@@ -165,47 +165,9 @@ class StorageManager(object):
                 ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix())
             elif suffix == ".tar.gz":
                 with tarfile.open(cached_file.as_posix()) as file:
-                    def is_within_directory(directory, target):
-                        abs_directory = os.path.abspath(directory)
-                        abs_target = os.path.abspath(target)
-                        prefix = os.path.commonprefix([abs_directory, abs_target])
-                        return prefix == abs_directory
-
-                    def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
-                        for member in tar.getmembers():
-                            member_path = os.path.join(path, member.name)
-                            if not is_within_directory(path, member_path):
-                                raise Exception("Attempted Path Traversal in Tar File")
-                        tar.extractall(path, members, numeric_owner=numeric_owner)
-
                     safe_extract(file, temp_target_folder.as_posix())
             elif suffix == ".tgz":
                 with tarfile.open(cached_file.as_posix(), mode='r:gz') as file:
-                    def is_within_directory(directory, target):
-                        abs_directory = os.path.abspath(directory)
-                        abs_target = os.path.abspath(target)
-                        prefix = os.path.commonprefix([abs_directory, abs_target])
-                        return prefix == abs_directory
-
-                    def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
-                        for member in tar.getmembers():
-                            member_path = os.path.join(path, member.name)
-                            if not is_within_directory(path, member_path):
-                                raise Exception("Attempted Path Traversal in Tar File")
-                        tar.extractall(path, members, numeric_owner=numeric_owner)
-
                     safe_extract(file, temp_target_folder.as_posix())

             if temp_target_folder != target_folder:
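Net effect of this file's change: both tar branches now delegate to the shared helper imported from .util instead of redefining it inline. A minimal sketch of the resulting extraction logic (surrounding method code omitted):

    from .util import safe_extract  # as added in the import hunk above

    if suffix == ".zip":
        ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix())
    elif suffix == ".tar.gz":
        with tarfile.open(cached_file.as_posix()) as file:
            safe_extract(file, temp_target_folder.as_posix())
    elif suffix == ".tgz":
        with tarfile.open(cached_file.as_posix(), mode='r:gz') as file:
            safe_extract(file, temp_target_folder.as_posix())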

clearml/storage/util.py

@@ -1,14 +1,15 @@
+import fnmatch
 import hashlib
 import json
+import os.path
 import re
 import sys
-from zlib import crc32
 from typing import Optional, Union, Sequence, Dict
-from pathlib2 import Path
+from zlib import crc32
-from six.moves.urllib.parse import quote, urlparse, urlunparse
 import six
-import fnmatch
+from pathlib2 import Path
+from six.moves.urllib.parse import quote, urlparse, urlunparse
 from ..debugging.log import LoggerRoot
@@ -16,11 +17,13 @@ from ..debugging.log import LoggerRoot
 def get_config_object_matcher(**patterns):
     unsupported = {k: v for k, v in patterns.items() if not isinstance(v, six.string_types)}
     if unsupported:
-        raise ValueError('Unsupported object matcher (expecting string): %s'
-                         % ', '.join('%s=%s' % (k, v) for k, v in unsupported.items()))
+        raise ValueError(
+            "Unsupported object matcher (expecting string): %s"
+            % ", ".join("%s=%s" % (k, v) for k, v in unsupported.items())
+        )

     # optimize simple patters
-    starts_with = {k: v.rstrip('*') for k, v in patterns.items() if '*' not in v.rstrip('*') and '?' not in v}
+    starts_with = {k: v.rstrip("*") for k, v in patterns.items() if "*" not in v.rstrip("*") and "?" not in v}
     patterns = {k: v for k, v in patterns.items() if v not in starts_with}

     def _matcher(**kwargs):
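Not part of the diff, just an illustration of the validation above (the _matcher body is unchanged and not shown here):

    get_config_object_matcher(name="model*")  # string pattern; trailing-* patterns are optimized to prefix checks
    get_config_object_matcher(name=42)        # ValueError: Unsupported object matcher (expecting string): name=42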
@@ -60,7 +63,7 @@ def sha256sum(filename, skip_header=0, block_size=65536):
     b = bytearray(block_size)
     mv = memoryview(b)
     try:
-        with open(filename, 'rb', buffering=0) as f:
+        with open(filename, "rb", buffering=0) as f:
             # skip header
             if skip_header:
                 file_hash.update(f.read(skip_header))
@@ -86,7 +89,7 @@ def md5text(text, seed=1337):
     :param seed: use prefix seed for hashing
     :return: md5 string
     """
-    return hash_text(text=text, seed=seed, hash_func='md5')
+    return hash_text(text=text, seed=seed, hash_func="md5")


 def crc32text(text, seed=1337):
@@ -99,10 +102,10 @@ def crc32text(text, seed=1337):
     :param seed: use prefix seed for hashing
     :return: crc32 hex in string (32bits = 8 characters in hex)
     """
-    return '{:08x}'.format(crc32((str(seed)+str(text)).encode('utf-8')))
+    return "{:08x}".format(crc32((str(seed) + str(text)).encode("utf-8")))


-def hash_text(text, seed=1337, hash_func='md5'):
+def hash_text(text, seed=1337, hash_func="md5"):
     # type: (str, Union[int, str], str) -> str
     """
     Return hash_func (md5/sha1/sha256/sha384/sha512) hash of a string
@@ -112,13 +115,13 @@ def hash_text(text, seed=1337, hash_func='md5'):
     :param hash_func: hashing function. currently supported md5 sha256
     :return: hashed string
     """
-    assert hash_func in ('md5', 'sha256', 'sha256', 'sha384', 'sha512')
+    assert hash_func in ("md5", "sha256", "sha256", "sha384", "sha512")
     h = getattr(hashlib, hash_func)()
-    h.update((str(seed) + str(text)).encode('utf-8'))
+    h.update((str(seed) + str(text)).encode("utf-8"))
     return h.hexdigest()


-def hash_dict(a_dict, seed=1337, hash_func='md5'):
+def hash_dict(a_dict, seed=1337, hash_func="md5"):
     # type: (Dict, Union[int, str], str) -> str
     """
     Return hash_func (crc32/md5/sha1/sha256/sha384/sha512) hash of the dict values
@@ -129,9 +132,9 @@ def hash_dict(a_dict, seed=1337, hash_func='md5'):
     :param hash_func: hashing function. currently supported md5 sha256
     :return: hashed string
     """
-    assert hash_func in ('crc32', 'md5', 'sha256', 'sha256', 'sha384', 'sha512')
+    assert hash_func in ("crc32", "md5", "sha256", "sha256", "sha384", "sha512")
     repr_string = json.dumps(a_dict, sort_keys=True)
-    if hash_func == 'crc32':
+    if hash_func == "crc32":
         return crc32text(repr_string, seed=seed)
     else:
         return hash_text(repr_string, seed=seed, hash_func=hash_func)
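For orientation (illustrative, not part of the commit): every helper prefixes the seed to the input before hashing, and hash_dict hashes a canonical JSON dump of the dict:

    hash_text("hello")                              # md5 hexdigest of "1337hello"
    hash_text("hello", seed=0, hash_func="sha256")  # sha256 of "0hello"
    crc32text("hello")                              # crc32 of "1337hello", 8 hex characters
    hash_dict({"b": 2, "a": 1})                     # md5 of "1337" + json.dumps({...}, sort_keys=True)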
@@ -141,7 +144,7 @@ def is_windows():
     """
     :return: True if currently running on windows OS
     """
-    return sys.platform == 'win32'
+    return sys.platform == "win32"


 def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b_instead_of_bytes=False):
@@ -185,7 +188,7 @@ def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b
     for i, m in enumerate(scale):
         if size < k ** (i + 1) or i == len(scale) - 1:
             return (
-                ("{:.2f}".format(size / (k ** i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size)))
+                ("{:.2f}".format(size / (k**i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size)))
                 + " "
                 + m
             )
@@ -226,42 +229,53 @@ def parse_size(size, binary=False):
     >>> parse_size('1.5 GB', binary=True)
     1610612736
     """

     def tokenize(text):
         tokenized_input = []
-        for token in re.split(r'(\d+(?:\.\d+)?)', text):
+        for token in re.split(r"(\d+(?:\.\d+)?)", text):
             token = token.strip()
-            if re.match(r'\d+\.\d+', token):
+            if re.match(r"\d+\.\d+", token):
                 tokenized_input.append(float(token))
             elif token.isdigit():
                 tokenized_input.append(int(token))
             elif token:
                 tokenized_input.append(token)
         return tokenized_input

     tokens = tokenize(str(size))
     if tokens and isinstance(tokens[0], (int, float)):
-        disk_size_units_b = \
-            (('B', 'bytes'), ('KiB', 'kibibyte'), ('MiB', 'mebibyte'), ('GiB', 'gibibyte'),
-             ('TiB', 'tebibyte'), ('PiB', 'pebibyte'))
-        disk_size_units_d = \
-            (('B', 'bytes'), ('KB', 'kilobyte'), ('MB', 'megabyte'), ('GB', 'gigabyte'),
-             ('TB', 'terabyte'), ('PB', 'petabyte'))
-        disk_size_units_b = [(1024 ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)]
+        disk_size_units_b = (
+            ("B", "bytes"),
+            ("KiB", "kibibyte"),
+            ("MiB", "mebibyte"),
+            ("GiB", "gibibyte"),
+            ("TiB", "tebibyte"),
+            ("PiB", "pebibyte"),
+        )
+        disk_size_units_d = (
+            ("B", "bytes"),
+            ("KB", "kilobyte"),
+            ("MB", "megabyte"),
+            ("GB", "gigabyte"),
+            ("TB", "terabyte"),
+            ("PB", "petabyte"),
+        )
+        disk_size_units_b = [(1024**i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)]
         k = 1024 if binary else 1000
-        disk_size_units_d = [(k ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)]
-        disk_size_units = (disk_size_units_b + disk_size_units_d) \
-            if binary else (disk_size_units_d + disk_size_units_b)
+        disk_size_units_d = [(k**i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)]
+        disk_size_units = (disk_size_units_b + disk_size_units_d) if binary else (disk_size_units_d + disk_size_units_b)

         # Get the normalized unit (if any) from the tokenized input.
-        normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else ''
+        normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else ""

         # If the input contains only a number, it's assumed to be the number of
         # bytes. The second token can also explicitly reference the unit bytes.
-        if len(tokens) == 1 or normalized_unit.startswith('b'):
+        if len(tokens) == 1 or normalized_unit.startswith("b"):
             return int(tokens[0])

         # Otherwise we expect two tokens: A number and a unit.
         if normalized_unit:
             # Convert plural units to singular units, for details:
             # https://github.com/xolox/python-humanfriendly/issues/26
-            normalized_unit = normalized_unit.rstrip('s')
+            normalized_unit = normalized_unit.rstrip("s")

         for k, low, high in disk_size_units:
             # First we check for unambiguous symbols (KiB, MiB, GiB, etc)
             # and names (kibibyte, mebibyte, gibibyte, etc) because their
@@ -271,8 +285,7 @@ def parse_size(size, binary=False):
             # Now we will deal with ambiguous prefixes (K, M, G, etc),
             # symbols (KB, MB, GB, etc) and names (kilobyte, megabyte,
             # gigabyte, etc) according to the caller's preference.
-            if (normalized_unit in (low.lower(), high.lower()) or
-                    normalized_unit.startswith(low.lower())):
+            if normalized_unit in (low.lower(), high.lower()) or normalized_unit.startswith(low.lower()):
                 return int(tokens[0] * k)

     raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens))
@@ -301,8 +314,7 @@ def get_common_path(list_of_files):
         if f_parts[:num_p] == common_path_parts[:num_p]:
             common_path_parts = common_path_parts[:num_p]
             continue
-        num_p = min(
-            [i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1])
+        num_p = min([i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1])
         # no common path, break
         if num_p < 0:
             common_path_parts = []
@@ -317,3 +329,19 @@ def get_common_path(list_of_files):
         return common_path.as_posix()

     return None
+
+
+def is_within_directory(directory, target):
+    abs_directory = os.path.abspath(directory)
+    abs_target = os.path.abspath(target)
+    prefix = os.path.commonprefix([abs_directory, abs_target])
+    return prefix == abs_directory
+
+
+def safe_extract(tar, path=".", members=None, numeric_owner=False):
+    """Tarfile member sanitization (addresses CVE-2007-4559)"""
+    for member in tar.getmembers():
+        member_path = os.path.join(path, member.name)
+        if not is_within_directory(path, member_path):
+            raise Exception("Attempted Path Traversal in Tar File")
+    tar.extractall(path, members, numeric_owner=numeric_owner)
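To see the sanitization in action, a minimal self-contained sketch (the archive member name and target path are invented for the demonstration):

    import io
    import tarfile

    # Build an in-memory tar whose member tries to escape the target directory.
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        data = b"malicious"
        info = tarfile.TarInfo(name="../escape.txt")
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))
    buf.seek(0)

    with tarfile.open(fileobj=buf, mode="r") as tar:
        safe_extract(tar, path="./extract_target")  # raises: Attempted Path Traversal in Tar File

One caveat worth knowing: os.path.commonprefix compares strings character by character rather than path components, so a sibling directory such as /tmp/foobar still shares a prefix with /tmp/foo; os.path.commonpath would be the stricter check.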