Fix PY2 compatibility

This commit is contained in:
allegroai 2022-09-13 15:04:49 +03:00
parent bc1a243ecd
commit a18b7b3271
8 changed files with 50 additions and 32 deletions

View File

@ -141,7 +141,7 @@ class Artifact(object):
self._object = self._not_set self._object = self._not_set
def get(self, force_download=False, deserialization_function=None): def get(self, force_download=False, deserialization_function=None):
# type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any # type: (bool, Optional[Callable[bytes, Any]]) -> Any
""" """
Return an object constructed from the artifact file Return an object constructed from the artifact file
@ -156,7 +156,7 @@ class Artifact(object):
pointing to a local copy of the artifacts file (or directory) will be returned pointing to a local copy of the artifacts file (or directory) will be returned
:param bool force_download: download file from remote even if exists in local cache :param bool force_download: download file from remote even if exists in local cache
:param Callable[Union[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`, :param Callable[bytes, Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object. which represents the serialized object. This function should return the deserialized object.
Useful when the artifact was uploaded using a custom serialization function when calling the Useful when the artifact was uploaded using a custom serialization function when calling the
`Task.upload_artifact` method with the `serialization_function` argument. `Task.upload_artifact` method with the `serialization_function` argument.
@ -426,9 +426,9 @@ class Artifacts(object):
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes] artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
except Exception: except Exception:
artifact_type_data.preview = "" artifact_type_data.preview = ""
override_filename_ext_in_uri = "" override_filename_ext_in_uri = extension_name or ""
override_filename_in_uri = name override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".") fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri)
os.close(fd) os.close(fd)
# noinspection PyBroadException # noinspection PyBroadException
try: try:

View File

@ -6,12 +6,24 @@ except ImportError:
fire = None fire = None
import inspect import inspect
from types import SimpleNamespace
from .frameworks import _patched_call # noqa from .frameworks import _patched_call # noqa
from ..config import get_remote_task_id, running_remotely from ..config import get_remote_task_id, running_remotely
from ..utilities.dicts import cast_str_to_bool from ..utilities.dicts import cast_str_to_bool
class SimpleNamespace(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __repr__(self):
keys = sorted(self.__dict__)
items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
return "{}({})".format(type(self).__name__, ", ".join(items))
def __eq__(self, other):
return self.__dict__ == other.__dict__
class PatchFire: class PatchFire:
_args = {} _args = {}
_command_type = "fire.Command" _command_type = "fire.Command"
@ -67,22 +79,22 @@ class PatchFire:
else: else:
args[cls._section_name + cls._args_sep + cls.__current_command] = True args[cls._section_name + cls._args_sep + cls.__current_command] = True
parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type
args = { args.update(
**args, {
**{
cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v
for k, v in cls._args.items() for k, v in cls._args.items()
if k in (PatchFire.__command_args.get(cls.__current_command) or []) if k in (PatchFire.__command_args.get(cls.__current_command) or [])
}, }
**{ )
args.update(
{
cls._section_name + cls._args_sep + k: v cls._section_name + cls._args_sep + k: v
for k, v in cls._args.items() for k, v in cls._args.items()
if k not in (PatchFire.__command_args.get(cls.__current_command) or []) if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
},
} }
parameters_types = { )
**parameters_types, parameters_types.update(
**{ {
cls._section_name cls._section_name
+ cls._args_sep + cls._args_sep
+ cls.__current_command + cls.__current_command
@ -90,13 +102,15 @@ class PatchFire:
+ k: cls._command_arg_type_template % cls.__current_command + k: cls._command_arg_type_template % cls.__current_command
for k in cls._args.keys() for k in cls._args.keys()
if k in (PatchFire.__command_args.get(cls.__current_command) or []) if k in (PatchFire.__command_args.get(cls.__current_command) or [])
}, }
**{ )
parameters_types.update(
{
cls._section_name + cls._args_sep + k: cls._shared_arg_type cls._section_name + cls._args_sep + k: cls._shared_arg_type
for k in cls._args.keys() for k in cls._args.keys()
if k not in (PatchFire.__command_args.get(cls.__current_command) or []) if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
},
} }
)
for command in cls.__commands: for command in cls.__commands:
if command == cls.__current_command: if command == cls.__current_command:
continue continue
@ -114,8 +128,8 @@ class PatchFire:
+ k: cls._command_arg_type_template % command + k: cls._command_arg_type_template % command
for k in (cls.__command_args.get(command) or []) for k in (cls.__command_args.get(command) or [])
} }
args = {**args, **unused_command_args} args.update(unused_command_args)
parameters_types = {**parameters_types, **unused_paramenters_types} parameters_types.update(unused_paramenters_types)
# noinspection PyProtectedMember # noinspection PyProtectedMember
cls._current_task._set_parameters( cls._current_task._set_parameters(
@ -186,7 +200,8 @@ class PatchFire:
fn_spec = fire.inspectutils.GetFullArgSpec(component) fn_spec = fire.inspectutils.GetFullArgSpec(component)
parse = fire.core._MakeParseFn(fn, metadata) # noqa parse = fire.core._MakeParseFn(fn, metadata) # noqa
(parsed_args, parsed_kwargs), _, _, _ = parse(args_) (parsed_args, parsed_kwargs), _, _, _ = parse(args_)
PatchFire._args = {**PatchFire._args, **{k: v for k, v in zip(fn_spec.args, parsed_args)}, **parsed_kwargs} PatchFire._args.update({k: v for k, v in zip(fn_spec.args, parsed_args)})
PatchFire._args.update(parsed_args)
PatchFire._update_task_args() PatchFire._update_task_args()
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs) return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)

View File

@ -3013,7 +3013,7 @@ class Dataset(object):
only_fields=["id", "runtime.version"], only_fields=["id", "runtime.version"],
search_hidden=True, search_hidden=True,
_allow_extra_fields_=True, _allow_extra_fields_=True,
**dataset_filter, **dataset_filter
) )
except Exception: except Exception:
datasets = [] datasets = []

View File

@ -1255,8 +1255,8 @@ class _HttpDriver(_Driver):
headers = { headers = {
'Content-Type': m.content_type, 'Content-Type': m.content_type,
**(container.get_headers(url) or {}),
} }
headers.update(container.get_headers(url) or {})
if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'): if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'):
pos = iterator.tell() pos = iterator.tell()
@ -1537,9 +1537,9 @@ class _Boto3Driver(_Driver):
stream = _Stream(iterator) stream = _Stream(iterator)
try: try:
extra_args = { extra_args = {
'ContentType': get_file_mimetype(object_name), 'ContentType': get_file_mimetype(object_name)
**(container.config.extra_args or {})
} }
extra_args.update(container.config.extra_args or {})
container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig( container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig(
use_threads=container.config.multipart, use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
@ -1556,9 +1556,9 @@ class _Boto3Driver(_Driver):
import boto3.s3.transfer import boto3.s3.transfer
try: try:
extra_args = { extra_args = {
'ContentType': get_file_mimetype(object_name or file_path), 'ContentType': get_file_mimetype(object_name or file_path)
**(container.config.extra_args or {})
} }
extra_args.update(container.config.extra_args or {})
container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig( container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig(
use_threads=container.config.multipart, use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1, max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,

View File

@ -1920,6 +1920,7 @@ class Task(_Task):
- pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle`` (default ``.csv.gz``) - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle`` (default ``.csv.gz``)
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``) - numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
- PIL.Image - whatever extensions PIL supports (default ``.png``) - PIL.Image - whatever extensions PIL supports (default ``.png``)
- In case the ``serialization_function`` argument is set - any extension is supported
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one :param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one
parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray` parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray`

View File

@ -19,7 +19,7 @@ class InvalidVersion(ValueError):
@attrs @attrs
class _Version: class _Version(object):
epoch = attrib() epoch = attrib()
release = attrib() release = attrib()
dev = attrib() dev = attrib()

View File

@ -11,7 +11,8 @@ Pillow>=4.1.1
psutil>=3.4.2 psutil>=3.4.2
pyparsing>=2.0.3 pyparsing>=2.0.3
python-dateutil>=2.6.1 python-dateutil>=2.6.1
pyjwt>=2.4.0,<2.5.0 pyjwt>=2.4.0,<2.5.0 ; python_version > '3.5'
pyjwt>=1.6.4,<2.0.0 ; python_version <= '3.5'
PyYAML>=3.12 PyYAML>=3.12
requests>=2.20.0 requests>=2.20.0
six>=1.13.0 six>=1.13.0

View File

@ -6,10 +6,11 @@ https://github.com/allegroai/clearml
import os.path import os.path
# Always prefer setuptools over distutils # Always prefer setuptools over distutils
from setuptools import setup, find_packages from setuptools import setup, find_packages
import codecs
def read_text(filepath): def read_text(filepath):
with open(filepath, "r", encoding="utf-8") as f: with codecs.open(filepath, "r", encoding="utf-8") as f:
return f.read() return f.read()