diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py index 61706d78..5e4007cb 100644 --- a/clearml/binding/frameworks/pytorch_bind.py +++ b/clearml/binding/frameworks/pytorch_bind.py @@ -104,10 +104,19 @@ class PatchPyTorchModelIO(PatchBaseModelIO): @staticmethod def _save(original_fn, obj, f, *args, **kwargs): ret = original_fn(obj, f, *args, **kwargs) + # if there is no main task or this is a nested call if not PatchPyTorchModelIO.__main_task: return ret + # pytorch-lightning check if rank is zero + if hasattr(obj, 'is_global_zero'): + if not obj.is_global_zero: + return ret + elif hasattr(obj, 'trainer') and hasattr(obj.trainer, 'is_global_zero'): + if not obj.trainer.is_global_zero: + return ret + # noinspection PyBroadException try: if isinstance(f, six.string_types):