Add type hints, remove/ignore pep8 warnings

This commit is contained in:
allegroai 2020-05-31 12:02:15 +03:00
parent 7b3a554fe9
commit 183ad248cf

View File

@ -17,7 +17,7 @@ import six
from PIL import Image
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any
from typing import Dict, Union, Optional, Any, Sequence
from ..backend_api import Session
from ..backend_api.services import tasks
@ -27,8 +27,10 @@ from ..storage.helper import remote_driver_schemes
try:
import pandas as pd
DataFrame = pd.DataFrame
except ImportError:
pd = None
DataFrame = None
try:
import numpy as np
except ImportError:
@ -144,6 +146,7 @@ class Artifact(object):
local_file = self.get_local_copy(raise_on_error=True)
# noinspection PyProtectedMember
if self.type == 'numpy' and np:
self._object = np.load(local_file)[self.name]
elif self.type in ('pandas', Artifacts._pd_artifact_type) and pd:
@ -261,7 +264,7 @@ class Artifacts(object):
self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, object, Optional[dict], bool) -> ()
# type: (str, DataFrame, Optional[dict], Union[bool, Sequence[str]]) -> ()
"""
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
@ -355,6 +358,7 @@ class Artifacts(object):
artifact_object = Path(artifact_object)
artifact_object.expanduser().absolute()
# noinspection PyBroadException
try:
create_zip_file = not artifact_object.is_file()
except Exception: # Hack for windows pathlib2 bug, is_file isn't valid.
@ -392,7 +396,7 @@ class Artifacts(object):
# failed uploading folder:
LoggerRoot.get_base_logger().warning('Exception {}\nFailed zipping artifact folder {}'.format(
folder, e))
return None
return False
finally:
os.close(fd)
@ -467,7 +471,7 @@ class Artifacts(object):
self._flush_event.set()
def stop(self, wait=True):
# type: (str) -> ()
# type: (bool) -> ()
# stop the daemon thread and quit
# wait until thread exists
self._exit_flag = True
@ -477,6 +481,7 @@ class Artifacts(object):
self._thread.join()
# remove all temp folders
for f in self._temp_folder:
# noinspection PyBroadException
try:
Path(f).rmdir()
except Exception:
@ -535,6 +540,7 @@ class Artifacts(object):
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
# noinspection PyBroadException
try:
local_csv.unlink()
except Exception:
@ -604,6 +610,7 @@ class Artifacts(object):
_, uri = ev.get_target_full_upload_uri(upload_uri)
# send for upload
# noinspection PyProtectedMember
self._task.reporter._report(ev)
return uri
@ -659,11 +666,13 @@ class Artifacts(object):
# build intersection summary
for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
name=name, shape=shape, unique=len(unique_hash), percentage=100 * len(unique_hash) / float(shape[0]))
name=name, shape=shape, unique=len(unique_hash),
percentage=100 * len(unique_hash) / float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
intersection = len(unique_hash & unique_hash2)
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
name2=name2, intersection=intersection, percentage=100 * intersection / float(len(unique_hash2)))
name2=name2, intersection=intersection,
percentage=100 * intersection / float(len(unique_hash2)))
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
finally:
@ -682,6 +691,7 @@ class Artifacts(object):
def _get_storage_uri_prefix(self):
# type: () -> str
if not self._storage_prefix:
# noinspection PyProtectedMember
self._storage_prefix = self._task._get_output_destination_suffix()
return self._storage_prefix
@ -699,6 +709,7 @@ class Artifacts(object):
# skip header
if skip_header:
file_hash.update(f.read(skip_header))
# noinspection PyUnresolvedReferences
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
if skip_header: