Fix Model.download_model_weights() to reuse previously downloaded file

This commit is contained in:
allegroai 2020-03-22 18:11:30 +02:00
parent 477665ee33
commit 63507c82f7

View File

@ -376,6 +376,22 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def tags(self): def tags(self):
return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags
@property
def task(self):
try:
return self.data.task
except ValueError:
# no task is yet specified
return None
@property
def uri(self):
try:
return self.data.uri
except ValueError:
# no uri is yet specified
return None
@property @property
def locked(self): def locked(self):
if self.id is None: if self.id is None:
@ -388,20 +404,19 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if not uri or not uri.strip(): if not uri or not uri.strip():
return None return None
helper = StorageHelper.get(uri) # check if we already downloaded the file
filename = uri.split('/')[-1] downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
ext = '.'.join(filename.split('.')[1:]) for dl_file in downloaded_models:
fd, local_filename = mkstemp(suffix='.'+ext) if Path(dl_file).exists():
os.close(fd) return dl_file
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True) # remove non existing model file
# if we ended up without any local copy, delete the temp file Model._local_model_to_id_uri.pop(dl_file, None)
if local_download != local_filename:
try: local_download = StorageHelper.get(uri).get_local_copy(uri)
Path(local_filename).unlink()
except Exception:
pass
# save local model, so we can later query what was the original one # save local model, so we can later query what was the original one
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
return local_download return local_download
@property @property