Add force_download argument to Artifact.get() and Artifact.get_local_copy() (#319)

This commit is contained in:
H4dr1en 2021-03-09 16:42:45 +01:00 committed by GitHub
parent 41cad87b77
commit f067caca52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -135,20 +135,21 @@ class Artifact(object):
self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None
self._object = None self._object = None
def get(self): def get(self, force_download=False):
# type: () -> Any # type: (bool) -> Any
""" """
Return an object constructed from the artifact file Return an object constructed from the artifact file
Currently supported types: Numpy.array, pandas.DataFrame, PIL.Image, dict (json) Currently supported types: Numpy.array, pandas.DataFrame, PIL.Image, dict (json)
All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory) All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory)
:param bool force_download: download file from remote even if exists in local cache
:return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path. :return: One of the following objects Numpy.array, pandas.DataFrame, PIL.Image, dict (json), or pathlib2.Path.
""" """
if self._object: if self._object:
return self._object return self._object
local_file = self.get_local_copy(raise_on_error=True) local_file = self.get_local_copy(raise_on_error=True, force_download=force_download)
# noinspection PyProtectedMember # noinspection PyProtectedMember
if self.type == 'numpy' and np: if self.type == 'numpy' and np:
@ -176,13 +177,14 @@ class Artifact(object):
return self._object return self._object
def get_local_copy(self, extract_archive=True, raise_on_error=False): def get_local_copy(self, extract_archive=True, raise_on_error=False, force_download=False):
# type: (bool, bool) -> str # type: (bool, bool, bool) -> str
""" """
:param bool extract_archive: If True and artifact is of type 'archive' (compressed folder) :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
The returned path will be a temporary folder containing the archive content The returned path will be a temporary folder containing the archive content
:param bool raise_on_error: If True and the artifact could not be downloaded, :param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning. raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: download file from remote even if exists in local cache
:raise: Raises error if local copy not found. :raise: Raises error if local copy not found.
:return: A local path to a downloaded copy of the artifact. :return: A local path to a downloaded copy of the artifact.
""" """
@ -190,7 +192,8 @@ class Artifact(object):
local_copy = StorageManager.get_local_copy( local_copy = StorageManager.get_local_copy(
remote_url=self.url, remote_url=self.url,
extract_archive=extract_archive and self.type == 'archive', extract_archive=extract_archive and self.type == 'archive',
name=self.name name=self.name,
force_download=force_download
) )
if raise_on_error and local_copy is None: if raise_on_error and local_copy is None:
raise ValueError( raise ValueError(