Fix Model.get_weights_package() returns None on error

This commit is contained in:
allegroai 2021-11-30 21:14:18 +02:00
parent 11892a2145
commit 297f33703f

View File

@ -286,7 +286,7 @@ class BaseModel(object):
return self._get_base_model().download_model_weights(raise_on_error=raise_on_error)
def get_weights_package(self, return_path=False, raise_on_error=False):
# type: (bool, bool) -> Union[str, List[Path]]
# type: (bool, bool) -> Optional[Union[str, List[Path]]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -300,6 +300,7 @@ class BaseModel(object):
raise ValueError, otherwise return None on failure and output log warning.
:return: The model weights, or a list of the locally stored filenames.
if raise_on_error=False, returns None on error.
"""
# check if model was packaged
if not self._is_package():
@ -308,6 +309,11 @@ class BaseModel(object):
# download packaged model
packed_file = self.get_weights(raise_on_error=raise_on_error)
if not packed_file:
if raise_on_error:
raise ValueError('Model package \'{}\' could not be downloaded'.format(self.url))
return None
# unpack
target_folder = mkdtemp(prefix='model_package_')
if not target_folder: