diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index 3e9dacbc..e83240ea 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -1,5 +1,9 @@ +import os +import shutil +import sys import threading import weakref +from tempfile import mkstemp import six from pathlib2 import Path @@ -200,7 +204,17 @@ class WeightsFileHandler(object): trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, target_filename=target_filename) else: - trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False) + # create a copy of the stored file, + # protect against someone deletes/renames the file before async upload finish is done + target_filename = Path(files[0]).name + # HACK: if pytorch-lightning is used, remove the temp '.part' file extension + if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'): + target_filename = target_filename[:-len('.part')] + fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp') + os.close(fd) + shutil.copy(files[0], temp_file) + trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True, + target_filename=target_filename) else: trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) except Exception as ex: