From 63507c82f7879d8f5d5486019b30748f97fd628f Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 22 Mar 2020 18:11:30 +0200 Subject: [PATCH] Fix Model.download_model_weights() to reuse previously downloaded file --- trains/backend_interface/model.py | 39 +++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py index c48b9360..a9bcdcf3 100644 --- a/trains/backend_interface/model.py +++ b/trains/backend_interface/model.py @@ -376,6 +376,22 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): def tags(self): return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags + @property + def task(self): + try: + return self.data.task + except ValueError: + # no task is yet specified + return None + + @property + def uri(self): + try: + return self.data.uri + except ValueError: + # no uri is yet specified + return None + @property def locked(self): if self.id is None: @@ -388,20 +404,19 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): if not uri or not uri.strip(): return None - helper = StorageHelper.get(uri) - filename = uri.split('/')[-1] - ext = '.'.join(filename.split('.')[1:]) - fd, local_filename = mkstemp(suffix='.'+ext) - os.close(fd) - local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True, verbose=True) - # if we ended up without any local copy, delete the temp file - if local_download != local_filename: - try: - Path(local_filename).unlink() - except Exception: - pass + # check if we already downloaded the file + downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri] + for dl_file in downloaded_models: + if Path(dl_file).exists(): + return dl_file + # remove non existing model file + Model._local_model_to_id_uri.pop(dl_file, None) + + local_download = StorageHelper.get(uri).get_local_copy(uri) + # save local model, so we can later query what was the original one Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) + return local_download @property