mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Refactor inner functions
This commit is contained in:
		
							parent
							
								
									8c776b6da0
								
							
						
					
					
						commit
						ff73e7848c
					
				| @ -14,7 +14,7 @@ from pathlib2 import Path | ||||
| 
 | ||||
| from .cache import CacheManager | ||||
| from .helper import StorageHelper | ||||
| from .util import encode_string_to_filename | ||||
| from .util import encode_string_to_filename, safe_extract | ||||
| from ..debugging.log import LoggerRoot | ||||
| 
 | ||||
| 
 | ||||
| @ -165,47 +165,9 @@ class StorageManager(object): | ||||
|                 ZipFile(cached_file.as_posix()).extractall(path=temp_target_folder.as_posix()) | ||||
|             elif suffix == ".tar.gz": | ||||
|                 with tarfile.open(cached_file.as_posix()) as file: | ||||
|                     def is_within_directory(directory, target): | ||||
|                          | ||||
|                         abs_directory = os.path.abspath(directory) | ||||
|                         abs_target = os.path.abspath(target) | ||||
|                      | ||||
|                         prefix = os.path.commonprefix([abs_directory, abs_target]) | ||||
|                          | ||||
|                         return prefix == abs_directory | ||||
|                      | ||||
|                     def safe_extract(tar, path=".", members=None, *, numeric_owner=False): | ||||
|                      | ||||
|                         for member in tar.getmembers(): | ||||
|                             member_path = os.path.join(path, member.name) | ||||
|                             if not is_within_directory(path, member_path): | ||||
|                                 raise Exception("Attempted Path Traversal in Tar File") | ||||
|                      | ||||
|                         tar.extractall(path, members, numeric_owner=numeric_owner)  | ||||
|                          | ||||
|                      | ||||
|                     safe_extract(file, temp_target_folder.as_posix()) | ||||
|             elif suffix == ".tgz": | ||||
|                 with tarfile.open(cached_file.as_posix(), mode='r:gz') as file: | ||||
|                     def is_within_directory(directory, target): | ||||
|                          | ||||
|                         abs_directory = os.path.abspath(directory) | ||||
|                         abs_target = os.path.abspath(target) | ||||
|                      | ||||
|                         prefix = os.path.commonprefix([abs_directory, abs_target]) | ||||
|                          | ||||
|                         return prefix == abs_directory | ||||
|                      | ||||
|                     def safe_extract(tar, path=".", members=None, *, numeric_owner=False): | ||||
|                      | ||||
|                         for member in tar.getmembers(): | ||||
|                             member_path = os.path.join(path, member.name) | ||||
|                             if not is_within_directory(path, member_path): | ||||
|                                 raise Exception("Attempted Path Traversal in Tar File") | ||||
|                      | ||||
|                         tar.extractall(path, members, numeric_owner=numeric_owner)  | ||||
|                          | ||||
|                      | ||||
|                     safe_extract(file, temp_target_folder.as_posix()) | ||||
| 
 | ||||
|             if temp_target_folder != target_folder: | ||||
|  | ||||
| @ -1,14 +1,15 @@ | ||||
| import fnmatch | ||||
| import hashlib | ||||
| import json | ||||
| import os.path | ||||
| import re | ||||
| import sys | ||||
| from zlib import crc32 | ||||
| from typing import Optional, Union, Sequence, Dict | ||||
| from pathlib2 import Path | ||||
| from zlib import crc32 | ||||
| 
 | ||||
| from six.moves.urllib.parse import quote, urlparse, urlunparse | ||||
| import six | ||||
| import fnmatch | ||||
| from pathlib2 import Path | ||||
| from six.moves.urllib.parse import quote, urlparse, urlunparse | ||||
| 
 | ||||
| from ..debugging.log import LoggerRoot | ||||
| 
 | ||||
| @ -16,11 +17,13 @@ from ..debugging.log import LoggerRoot | ||||
| def get_config_object_matcher(**patterns): | ||||
|     unsupported = {k: v for k, v in patterns.items() if not isinstance(v, six.string_types)} | ||||
|     if unsupported: | ||||
|         raise ValueError('Unsupported object matcher (expecting string): %s' | ||||
|                          % ', '.join('%s=%s' % (k, v) for k, v in unsupported.items())) | ||||
|         raise ValueError( | ||||
|             "Unsupported object matcher (expecting string): %s" | ||||
|             % ", ".join("%s=%s" % (k, v) for k, v in unsupported.items()) | ||||
|         ) | ||||
| 
 | ||||
|     # optimize simple patterns | ||||
|     starts_with = {k: v.rstrip('*') for k, v in patterns.items() if '*' not in v.rstrip('*') and '?' not in v} | ||||
|     starts_with = {k: v.rstrip("*") for k, v in patterns.items() if "*" not in v.rstrip("*") and "?" not in v} | ||||
|     patterns = {k: v for k, v in patterns.items() if v not in starts_with} | ||||
| 
 | ||||
|     def _matcher(**kwargs): | ||||
| @ -60,7 +63,7 @@ def sha256sum(filename, skip_header=0, block_size=65536): | ||||
|     b = bytearray(block_size) | ||||
|     mv = memoryview(b) | ||||
|     try: | ||||
|         with open(filename, 'rb', buffering=0) as f: | ||||
|         with open(filename, "rb", buffering=0) as f: | ||||
|             # skip header | ||||
|             if skip_header: | ||||
|                 file_hash.update(f.read(skip_header)) | ||||
| @ -86,7 +89,7 @@ def md5text(text, seed=1337): | ||||
|     :param seed: use prefix seed for hashing | ||||
|     :return: md5 string | ||||
|     """ | ||||
|     return hash_text(text=text, seed=seed, hash_func='md5') | ||||
|     return hash_text(text=text, seed=seed, hash_func="md5") | ||||
| 
 | ||||
| 
 | ||||
| def crc32text(text, seed=1337): | ||||
| @ -99,10 +102,10 @@ def crc32text(text, seed=1337): | ||||
|     :param seed: use prefix seed for hashing | ||||
|     :return: crc32 hex in string (32bits = 8 characters in hex) | ||||
|     """ | ||||
|     return '{:08x}'.format(crc32((str(seed)+str(text)).encode('utf-8'))) | ||||
|     return "{:08x}".format(crc32((str(seed) + str(text)).encode("utf-8"))) | ||||
| 
 | ||||
| 
 | ||||
def hash_text(text, seed=1337, hash_func="md5"):
    # type: (str, Union[int, str], str) -> str
    """
    Return hash_func (md5/sha1/sha256/sha384/sha512) hash of a string

    :param text: string to hash
    :param seed: use prefix seed for hashing
    :param hash_func: hashing function. currently supported md5 sha1 sha256 sha384 sha512
    :return: hashed string (hex digest)
    """
    # Fixed allowed set: 'sha1' is documented above but was missing,
    # and 'sha256' was accidentally listed twice.
    assert hash_func in ("md5", "sha1", "sha256", "sha384", "sha512")
    h = getattr(hashlib, hash_func)()
    # prefix the seed so equal texts with different seeds hash differently
    h.update((str(seed) + str(text)).encode("utf-8"))
    return h.hexdigest()
| 
 | ||||
| 
 | ||||
def hash_dict(a_dict, seed=1337, hash_func="md5"):
    # type: (Dict, Union[int, str], str) -> str
    """
    Return hash_func (crc32/md5/sha1/sha256/sha384/sha512) hash of the dict values

    :param a_dict: dict to hash (must be JSON-serializable)
    :param seed: use prefix seed for hashing
    :param hash_func: hashing function. currently supported crc32 md5 sha1 sha256 sha384 sha512
    :return: hashed string
    """
    # Fixed allowed set: 'sha1' is documented above but was missing,
    # and 'sha256' was accidentally listed twice.
    assert hash_func in ("crc32", "md5", "sha1", "sha256", "sha384", "sha512")
    # sort_keys makes the hash independent of dict insertion order
    repr_string = json.dumps(a_dict, sort_keys=True)
    if hash_func == "crc32":
        return crc32text(repr_string, seed=seed)
    else:
        return hash_text(repr_string, seed=seed, hash_func=hash_func)
| @ -141,7 +144,7 @@ def is_windows(): | ||||
|     """ | ||||
|     :return: True if currently running on windows OS | ||||
|     """ | ||||
|     return sys.platform == 'win32' | ||||
|     return sys.platform == "win32" | ||||
| 
 | ||||
| 
 | ||||
| def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b_instead_of_bytes=False): | ||||
| @ -185,7 +188,7 @@ def format_size(size_in_bytes, binary=False, use_nonbinary_notation=False, use_b | ||||
|     for i, m in enumerate(scale): | ||||
|         if size < k ** (i + 1) or i == len(scale) - 1: | ||||
|             return ( | ||||
|                 ("{:.2f}".format(size / (k ** i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size))) | ||||
|                 ("{:.2f}".format(size / (k**i)).rstrip("0").rstrip(".") if i > 0 else "{}".format(int(size))) | ||||
|                 + " " | ||||
|                 + m | ||||
|             ) | ||||
| @ -226,42 +229,53 @@ def parse_size(size, binary=False): | ||||
|         >>> parse_size('1.5 GB', binary=True) | ||||
|         1610612736 | ||||
|     """ | ||||
| 
 | ||||
|     def tokenize(text): | ||||
|         tokenized_input = [] | ||||
|         for token in re.split(r'(\d+(?:\.\d+)?)', text): | ||||
|         for token in re.split(r"(\d+(?:\.\d+)?)", text): | ||||
|             token = token.strip() | ||||
|             if re.match(r'\d+\.\d+', token): | ||||
|             if re.match(r"\d+\.\d+", token): | ||||
|                 tokenized_input.append(float(token)) | ||||
|             elif token.isdigit(): | ||||
|                 tokenized_input.append(int(token)) | ||||
|             elif token: | ||||
|                 tokenized_input.append(token) | ||||
|         return tokenized_input | ||||
| 
 | ||||
|     tokens = tokenize(str(size)) | ||||
|     if tokens and isinstance(tokens[0], (int, float)): | ||||
|         disk_size_units_b = \ | ||||
|             (('B', 'bytes'), ('KiB', 'kibibyte'), ('MiB', 'mebibyte'), ('GiB', 'gibibyte'), | ||||
|              ('TiB', 'tebibyte'), ('PiB', 'pebibyte')) | ||||
|         disk_size_units_d = \ | ||||
|             (('B', 'bytes'), ('KB', 'kilobyte'), ('MB', 'megabyte'), ('GB', 'gigabyte'), | ||||
|              ('TB', 'terabyte'), ('PB', 'petabyte')) | ||||
|         disk_size_units_b = [(1024 ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)] | ||||
|         disk_size_units_b = ( | ||||
|             ("B", "bytes"), | ||||
|             ("KiB", "kibibyte"), | ||||
|             ("MiB", "mebibyte"), | ||||
|             ("GiB", "gibibyte"), | ||||
|             ("TiB", "tebibyte"), | ||||
|             ("PiB", "pebibyte"), | ||||
|         ) | ||||
|         disk_size_units_d = ( | ||||
|             ("B", "bytes"), | ||||
|             ("KB", "kilobyte"), | ||||
|             ("MB", "megabyte"), | ||||
|             ("GB", "gigabyte"), | ||||
|             ("TB", "terabyte"), | ||||
|             ("PB", "petabyte"), | ||||
|         ) | ||||
|         disk_size_units_b = [(1024**i, s[0], s[1]) for i, s in enumerate(disk_size_units_b)] | ||||
|         k = 1024 if binary else 1000 | ||||
|         disk_size_units_d = [(k ** i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)] | ||||
|         disk_size_units = (disk_size_units_b + disk_size_units_d) \ | ||||
|             if binary else (disk_size_units_d + disk_size_units_b) | ||||
|         disk_size_units_d = [(k**i, s[0], s[1]) for i, s in enumerate(disk_size_units_d)] | ||||
|         disk_size_units = (disk_size_units_b + disk_size_units_d) if binary else (disk_size_units_d + disk_size_units_b) | ||||
| 
 | ||||
|         # Get the normalized unit (if any) from the tokenized input. | ||||
|         normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else '' | ||||
|         normalized_unit = tokens[1].lower() if len(tokens) == 2 and isinstance(tokens[1], str) else "" | ||||
|         # If the input contains only a number, it's assumed to be the number of | ||||
|         # bytes. The second token can also explicitly reference the unit bytes. | ||||
|         if len(tokens) == 1 or normalized_unit.startswith('b'): | ||||
|         if len(tokens) == 1 or normalized_unit.startswith("b"): | ||||
|             return int(tokens[0]) | ||||
|         # Otherwise we expect two tokens: A number and a unit. | ||||
|         if normalized_unit: | ||||
|             # Convert plural units to singular units, for details: | ||||
|             # https://github.com/xolox/python-humanfriendly/issues/26 | ||||
|             normalized_unit = normalized_unit.rstrip('s') | ||||
|             normalized_unit = normalized_unit.rstrip("s") | ||||
|             for k, low, high in disk_size_units: | ||||
|                 # First we check for unambiguous symbols (KiB, MiB, GiB, etc) | ||||
|                 # and names (kibibyte, mebibyte, gibibyte, etc) because their | ||||
| @ -271,8 +285,7 @@ def parse_size(size, binary=False): | ||||
|                 # Now we will deal with ambiguous prefixes (K, M, G, etc), | ||||
|                 # symbols (KB, MB, GB, etc) and names (kilobyte, megabyte, | ||||
|                 # gigabyte, etc) according to the caller's preference. | ||||
|                 if (normalized_unit in (low.lower(), high.lower()) or | ||||
|                         normalized_unit.startswith(low.lower())): | ||||
|                 if normalized_unit in (low.lower(), high.lower()) or normalized_unit.startswith(low.lower()): | ||||
|                     return int(tokens[0] * k) | ||||
| 
 | ||||
|     raise ValueError("Failed to parse size! (input {} was tokenized as {})".format(size, tokens)) | ||||
| @ -301,8 +314,7 @@ def get_common_path(list_of_files): | ||||
|         if f_parts[:num_p] == common_path_parts[:num_p]: | ||||
|             common_path_parts = common_path_parts[:num_p] | ||||
|             continue | ||||
|         num_p = min( | ||||
|             [i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1]) | ||||
|         num_p = min([i for i, (a, b) in enumerate(zip(common_path_parts[:num_p], f_parts[:num_p])) if a != b] or [-1]) | ||||
|         # no common path, break | ||||
|         if num_p < 0: | ||||
|             common_path_parts = [] | ||||
| @ -317,3 +329,19 @@ def get_common_path(list_of_files): | ||||
|         return common_path.as_posix() | ||||
| 
 | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
def is_within_directory(directory, target):
    """
    Return True if *target* resolves to a path inside (or equal to) *directory*.

    Uses ``os.path.commonpath`` instead of ``os.path.commonprefix``: the latter
    is purely textual, so e.g. directory ``/a/b`` and target ``/a/bc`` would
    incorrectly pass the containment check and defeat the tar-extraction
    path-traversal guard (CVE-2007-4559).

    :param directory: base directory path
    :param target: candidate path to validate
    :return: True if target is contained in directory
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # different drives (Windows) or incompatible paths - cannot be inside
        return False
| 
 | ||||
| 
 | ||||
def safe_extract(tar, path=".", members=None, numeric_owner=False):
    """Tarfile member sanitization (addresses CVE-2007-4559)"""
    # Validate every member first; refuse to extract anything if even a
    # single entry would land outside the target path.
    escaping = [
        member.name
        for member in tar.getmembers()
        if not is_within_directory(path, os.path.join(path, member.name))
    ]
    if escaping:
        raise Exception("Attempted Path Traversal in Tar File")
    tar.extractall(path, members, numeric_owner=numeric_owner)
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai