Add type hints, remove/ignore pep8 warnings

This commit is contained in:
allegroai 2020-05-31 12:02:15 +03:00
parent 7b3a554fe9
commit 183ad248cf

View File

@ -17,7 +17,7 @@ import six
from PIL import Image from PIL import Image
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any from typing import Dict, Union, Optional, Any, Sequence
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.services import tasks from ..backend_api.services import tasks
@ -27,8 +27,10 @@ from ..storage.helper import remote_driver_schemes
try: try:
import pandas as pd import pandas as pd
DataFrame = pd.DataFrame
except ImportError: except ImportError:
pd = None pd = None
DataFrame = None
try: try:
import numpy as np import numpy as np
except ImportError: except ImportError:
@ -144,6 +146,7 @@ class Artifact(object):
local_file = self.get_local_copy(raise_on_error=True) local_file = self.get_local_copy(raise_on_error=True)
# noinspection PyProtectedMember
if self.type == 'numpy' and np: if self.type == 'numpy' and np:
self._object = np.load(local_file)[self.name] self._object = np.load(local_file)[self.name]
elif self.type in ('pandas', Artifacts._pd_artifact_type) and pd: elif self.type in ('pandas', Artifacts._pd_artifact_type) and pd:
@ -261,7 +264,7 @@ class Artifacts(object):
self._storage_prefix = None self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, object, Optional[dict], bool) -> () # type: (str, DataFrame, Optional[dict], Union[bool, Sequence[str]]) -> ()
""" """
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists. :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame :param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
@ -355,6 +358,7 @@ class Artifacts(object):
artifact_object = Path(artifact_object) artifact_object = Path(artifact_object)
artifact_object.expanduser().absolute() artifact_object.expanduser().absolute()
# noinspection PyBroadException
try: try:
create_zip_file = not artifact_object.is_file() create_zip_file = not artifact_object.is_file()
except Exception: # Hack for windows pathlib2 bug, is_file isn't valid. except Exception: # Hack for windows pathlib2 bug, is_file isn't valid.
@ -392,7 +396,7 @@ class Artifacts(object):
# failed uploading folder: # failed uploading folder:
LoggerRoot.get_base_logger().warning('Exception {}\nFailed zipping artifact folder {}'.format( LoggerRoot.get_base_logger().warning('Exception {}\nFailed zipping artifact folder {}'.format(
folder, e)) folder, e))
return None return False
finally: finally:
os.close(fd) os.close(fd)
@ -467,7 +471,7 @@ class Artifacts(object):
self._flush_event.set() self._flush_event.set()
def stop(self, wait=True): def stop(self, wait=True):
# type: (str) -> () # type: (bool) -> ()
# stop the daemon thread and quit # stop the daemon thread and quit
# wait until thread exits # wait until thread exits
self._exit_flag = True self._exit_flag = True
@ -477,6 +481,7 @@ class Artifacts(object):
self._thread.join() self._thread.join()
# remove all temp folders # remove all temp folders
for f in self._temp_folder: for f in self._temp_folder:
# noinspection PyBroadException
try: try:
Path(f).rmdir() Path(f).rmdir()
except Exception: except Exception:
@ -535,6 +540,7 @@ class Artifacts(object):
previous_sha2 = self._last_artifacts_upload[name] previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2: if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload # nothing to do, we can skip the upload
# noinspection PyBroadException
try: try:
local_csv.unlink() local_csv.unlink()
except Exception: except Exception:
@ -604,6 +610,7 @@ class Artifacts(object):
_, uri = ev.get_target_full_upload_uri(upload_uri) _, uri = ev.get_target_full_upload_uri(upload_uri)
# send for upload # send for upload
# noinspection PyProtectedMember
self._task.reporter._report(ev) self._task.reporter._report(ev)
return uri return uri
@ -659,11 +666,13 @@ class Artifacts(object):
# build intersection summary # build intersection summary
for i, (name, shape, unique_hash) in enumerate(artifacts_summary): for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format( summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
name=name, shape=shape, unique=len(unique_hash), percentage=100 * len(unique_hash) / float(shape[0])) name=name, shape=shape, unique=len(unique_hash),
percentage=100 * len(unique_hash) / float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]: for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
intersection = len(unique_hash & unique_hash2) intersection = len(unique_hash & unique_hash2)
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format( summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
name2=name2, intersection=intersection, percentage=100 * intersection / float(len(unique_hash2))) name2=name2, intersection=intersection,
percentage=100 * intersection / float(len(unique_hash2)))
except Exception as e: except Exception as e:
LoggerRoot.get_base_logger().warning(str(e)) LoggerRoot.get_base_logger().warning(str(e))
finally: finally:
@ -682,6 +691,7 @@ class Artifacts(object):
def _get_storage_uri_prefix(self): def _get_storage_uri_prefix(self):
# type: () -> str # type: () -> str
if not self._storage_prefix: if not self._storage_prefix:
# noinspection PyProtectedMember
self._storage_prefix = self._task._get_output_destination_suffix() self._storage_prefix = self._task._get_output_destination_suffix()
return self._storage_prefix return self._storage_prefix
@ -699,6 +709,7 @@ class Artifacts(object):
# skip header # skip header
if skip_header: if skip_header:
file_hash.update(f.read(skip_header)) file_hash.update(f.read(skip_header))
# noinspection PyUnresolvedReferences
for n in iter(lambda: f.readinto(mv), 0): for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n]) h.update(mv[:n])
if skip_header: if skip_header: