From 63e7cbab3049c9f32bce9ee205ba98cc455ec477 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Thu, 15 Sep 2022 15:59:34 +0300
Subject: [PATCH] Improve PyTorch model saving, add support for mmcv

---
 clearml/binding/frameworks/__init__.py     |  9 +++-
 clearml/binding/frameworks/pytorch_bind.py | 57 ++++++++++++++++++++--
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/clearml/binding/frameworks/__init__.py b/clearml/binding/frameworks/__init__.py
index 97125987..60039e3d 100644
--- a/clearml/binding/frameworks/__init__.py
+++ b/clearml/binding/frameworks/__init__.py
@@ -44,6 +44,13 @@ def _patched_call(original_fn, patched_fn):
     return _inner_patch
 
 
+def _patched_call_no_recursion_guard(original_fn, patched_fn):
+    def _inner_patch(*args, **kwargs):
+        return patched_fn(original_fn, *args, **kwargs)
+
+    return _inner_patch
+
+
 class _Empty(object):
     def __init__(self):
         self.trains_in_model = None
@@ -159,7 +166,7 @@ class WeightsFileHandler(object):
         If the callback was already added, return the existing handle.
 
         :param handle: A callback handle returned from :meth:`WeightsFileHandler.add_post_callback`
-        :return True if callback removed, False otherwise
+        :return: True if callback removed, False otherwise
         """
         return cls._remove_callback(handle, cls._model_post_callbacks)
 
diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py
index 4cfea85a..9b8f8255 100644
--- a/clearml/binding/frameworks/pytorch_bind.py
+++ b/clearml/binding/frameworks/pytorch_bind.py
@@ -1,10 +1,12 @@
 import sys
 
 import six
+import threading
+
 from pathlib2 import Path
 
 from ...binding.frameworks.base_bind import PatchBaseModelIO
-from ..frameworks import _patched_call, WeightsFileHandler, _Empty
+from ..frameworks import _patched_call, _patched_call_no_recursion_guard, WeightsFileHandler, _Empty
 from ..import_bind import PostImportHookPatching
 from ...config import running_remotely
 from ...model import Framework
@@ -12,8 +14,11 @@ from ...model import Framework
 
 class PatchPyTorchModelIO(PatchBaseModelIO):
     _current_task = None
+    _checkpoint_filename = {}
     __patched = None
     __patched_lightning = None
+    __patched_mmcv = None
+    __default_checkpoint_filename_counter = {}
 
     @staticmethod
     def update_current_task(task, **_):
@@ -22,6 +27,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             return
 
         PatchPyTorchModelIO._patch_model_io()
         PatchPyTorchModelIO._patch_lightning_io()
+        PatchPyTorchModelIO._patch_mmcv()
         PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
         PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
@@ -65,6 +71,41 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         except Exception:
             pass  # print('Failed patching pytorch')
 
+    @staticmethod
+    def _patch_mmcv():
+        if PatchPyTorchModelIO.__patched_mmcv:
+            return
+        if "mmcv" not in sys.modules:
+            return
+        PatchPyTorchModelIO.__patched_mmcv = True
+
+        # noinspection PyBroadException
+        try:
+            from mmcv.runner import epoch_based_runner, iter_based_runner
+
+            # we don't want the recursion check here because it guards pytorch's patched save functions,
+            # which we need in order to log the saved model/checkpoint
+            epoch_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
+                epoch_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
+            )
+            iter_based_runner.save_checkpoint = _patched_call_no_recursion_guard(
+                iter_based_runner.save_checkpoint, PatchPyTorchModelIO._mmcv_save_checkpoint
+            )
+        except Exception:
+            pass
+
+    @staticmethod
+    def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs):
+        # note that mmcv.runner.save_checkpoint doesn't return anything, hence the need for this
+        # patch function, but we return from it just in case this changes in the future
+        if not PatchPyTorchModelIO._current_task:
+            return original_fn(model, filename, *args, **kwargs)
+        tid = threading.current_thread().ident
+        PatchPyTorchModelIO._checkpoint_filename[tid] = filename
+        ret = original_fn(model, filename, *args, **kwargs)
+        del PatchPyTorchModelIO._checkpoint_filename[tid]
+        return ret
+
     @staticmethod
     def _patch_lightning_io():
         if PatchPyTorchModelIO.__patched_lightning:
@@ -144,9 +185,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
                 filename = f.name
             else:
-                filename = None
+                filename = PatchPyTorchModelIO.__create_default_filename()
         except Exception:
-            filename = None
+            filename = PatchPyTorchModelIO.__create_default_filename()
 
         # give the model a descriptive name based on the file name
         # noinspection PyBroadException
         try:
@@ -241,3 +282,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             pass
 
         return model
+
+    @staticmethod
+    def __create_default_filename():
+        tid = threading.current_thread().ident
+        checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
+        if checkpoint_filename:
+            return checkpoint_filename
+        counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
+        PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
+        return "default_{}_{}".format(tid, counter)
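
Reviewer note: below is a minimal, self-contained sketch of the wrap-and-record pattern the mmcv hook above relies on. Every name in it (demo_save_checkpoint, _filename_by_thread, the /tmp path) is an illustrative stand-in, not part of this patch or of mmcv's API. It shows the two points the patch depends on: mmcv's save_checkpoint returns nothing and internally calls the already-patched torch.save, so the wrapper (a) records the target filename per thread so the inner patched save can name the model, and (b) skips the usual recursion guard so that inner patched call still fires.

    # sketch.py -- illustrative only; mirrors the shape of
    # _patched_call_no_recursion_guard and _mmcv_save_checkpoint
    import threading

    _filename_by_thread = {}  # thread ident -> checkpoint filename being written


    def demo_save_checkpoint(model, filename):
        # stand-in for mmcv.runner.epoch_based_runner.save_checkpoint;
        # returns None and would eventually call the patched torch.save
        print("writing %s to %s" % (model, filename))


    def _wrap_no_recursion_guard(original_fn, patched_fn):
        # same shape as _patched_call_no_recursion_guard: no guard, so patched
        # functions called inside original_fn still run their patched versions
        def _inner(*args, **kwargs):
            return patched_fn(original_fn, *args, **kwargs)
        return _inner


    def _record_checkpoint(original_fn, model, filename, *args, **kwargs):
        # record the filename for this thread, call through, then clean up;
        # the try/finally is a hardening choice for the sketch -- the patch
        # itself deletes the entry inline after the call returns
        tid = threading.current_thread().ident
        _filename_by_thread[tid] = filename
        try:
            return original_fn(model, filename, *args, **kwargs)
        finally:
            del _filename_by_thread[tid]


    demo_save_checkpoint = _wrap_no_recursion_guard(demo_save_checkpoint, _record_checkpoint)
    demo_save_checkpoint("model-state", "/tmp/model.pth")
    # while the call is in flight, an inner patched save on the same thread can
    # look up _filename_by_thread to recover the checkpoint name, which is what
    # __create_default_filename does with _checkpoint_filename in the patch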