Fix PY2 compatibility

This commit is contained in:
allegroai 2022-09-13 15:04:49 +03:00
parent bc1a243ecd
commit a18b7b3271
8 changed files with 50 additions and 32 deletions

View File

@ -141,7 +141,7 @@ class Artifact(object):
self._object = self._not_set
def get(self, force_download=False, deserialization_function=None):
# type: (bool, Optional[Callable[Union[bytes], Any]]) -> Any
# type: (bool, Optional[Callable[bytes, Any]]) -> Any
"""
Return an object constructed from the artifact file
@ -156,7 +156,7 @@ class Artifact(object):
pointing to a local copy of the artifacts file (or directory) will be returned
:param bool force_download: download file from remote even if exists in local cache
:param Callable[Union[bytes], Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
:param Callable[bytes, Any] deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
Useful when the artifact was uploaded using a custom serialization function when calling the
`Task.upload_artifact` method with the `serialization_function` argument.
@ -426,9 +426,9 @@ class Artifacts(object):
artifact_type_data.preview = preview or str(artifact_object.__repr__())[:self.max_preview_size_bytes]
except Exception:
artifact_type_data.preview = ""
override_filename_ext_in_uri = ""
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".")
override_filename_ext_in_uri = extension_name or ""
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri)
os.close(fd)
# noinspection PyBroadException
try:

View File

@ -6,12 +6,24 @@ except ImportError:
fire = None
import inspect
from types import SimpleNamespace
from .frameworks import _patched_call # noqa
from ..config import get_remote_task_id, running_remotely
from ..utilities.dicts import cast_str_to_bool
class SimpleNamespace(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __repr__(self):
keys = sorted(self.__dict__)
items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
return "{}({})".format(type(self).__name__, ", ".join(items))
def __eq__(self, other):
return self.__dict__ == other.__dict__
class PatchFire:
_args = {}
_command_type = "fire.Command"
@ -67,22 +79,22 @@ class PatchFire:
else:
args[cls._section_name + cls._args_sep + cls.__current_command] = True
parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type
args = {
**args,
**{
args.update(
{
cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v
for k, v in cls._args.items()
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
},
**{
}
)
args.update(
{
cls._section_name + cls._args_sep + k: v
for k, v in cls._args.items()
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
},
}
parameters_types = {
**parameters_types,
**{
)
parameters_types.update(
{
cls._section_name
+ cls._args_sep
+ cls.__current_command
@ -90,13 +102,15 @@ class PatchFire:
+ k: cls._command_arg_type_template % cls.__current_command
for k in cls._args.keys()
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
},
**{
}
)
parameters_types.update(
{
cls._section_name + cls._args_sep + k: cls._shared_arg_type
for k in cls._args.keys()
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
},
}
)
for command in cls.__commands:
if command == cls.__current_command:
continue
@ -114,8 +128,8 @@ class PatchFire:
+ k: cls._command_arg_type_template % command
for k in (cls.__command_args.get(command) or [])
}
args = {**args, **unused_command_args}
parameters_types = {**parameters_types, **unused_paramenters_types}
args.update(unused_command_args)
parameters_types.update(unused_paramenters_types)
# noinspection PyProtectedMember
cls._current_task._set_parameters(
@ -186,7 +200,8 @@ class PatchFire:
fn_spec = fire.inspectutils.GetFullArgSpec(component)
parse = fire.core._MakeParseFn(fn, metadata) # noqa
(parsed_args, parsed_kwargs), _, _, _ = parse(args_)
PatchFire._args = {**PatchFire._args, **{k: v for k, v in zip(fn_spec.args, parsed_args)}, **parsed_kwargs}
PatchFire._args.update({k: v for k, v in zip(fn_spec.args, parsed_args)})
PatchFire._args.update(parsed_args)
PatchFire._update_task_args()
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)

View File

@ -3013,7 +3013,7 @@ class Dataset(object):
only_fields=["id", "runtime.version"],
search_hidden=True,
_allow_extra_fields_=True,
**dataset_filter,
**dataset_filter
)
except Exception:
datasets = []

View File

@ -1255,8 +1255,8 @@ class _HttpDriver(_Driver):
headers = {
'Content-Type': m.content_type,
**(container.get_headers(url) or {}),
}
headers.update(container.get_headers(url) or {})
if hasattr(iterator, 'tell') and hasattr(iterator, 'seek'):
pos = iterator.tell()
@ -1537,9 +1537,9 @@ class _Boto3Driver(_Driver):
stream = _Stream(iterator)
try:
extra_args = {
'ContentType': get_file_mimetype(object_name),
**(container.config.extra_args or {})
'ContentType': get_file_mimetype(object_name)
}
extra_args.update(container.config.extra_args or {})
container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig(
use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,
@ -1556,9 +1556,9 @@ class _Boto3Driver(_Driver):
import boto3.s3.transfer
try:
extra_args = {
'ContentType': get_file_mimetype(object_name or file_path),
**(container.config.extra_args or {})
'ContentType': get_file_mimetype(object_name or file_path)
}
extra_args.update(container.config.extra_args or {})
container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig(
use_threads=container.config.multipart,
max_concurrency=self._max_multipart_concurrency if container.config.multipart else 1,

View File

@ -1920,6 +1920,7 @@ class Task(_Task):
- pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle`` (default ``.csv.gz``)
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
- PIL.Image - whatever extensions PIL supports (default ``.png``)
- In case the ``serialization_function`` argument is set - any extension is supported
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one
parameter of any types which is the object to be serialized. The function should return a `bytes` or `bytearray`

View File

@ -19,7 +19,7 @@ class InvalidVersion(ValueError):
@attrs
class _Version:
class _Version(object):
epoch = attrib()
release = attrib()
dev = attrib()

View File

@ -11,7 +11,8 @@ Pillow>=4.1.1
psutil>=3.4.2
pyparsing>=2.0.3
python-dateutil>=2.6.1
pyjwt>=2.4.0,<2.5.0
pyjwt>=2.4.0,<2.5.0 ; python_version > '3.5'
pyjwt>=1.6.4,<2.0.0 ; python_version <= '3.5'
PyYAML>=3.12
requests>=2.20.0
six>=1.13.0

View File

@ -6,10 +6,11 @@ https://github.com/allegroai/clearml
import os.path
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
import codecs
def read_text(filepath):
with open(filepath, "r", encoding="utf-8") as f:
with codecs.open(filepath, "r", encoding="utf-8") as f:
return f.read()