From 7c09251686b14cdbf5a9a64f89550034d04d9c1d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 21 May 2023 09:42:32 +0300 Subject: [PATCH] Fix code hangs when running with joblib (#1009) --- clearml/binding/frameworks/pytorch_bind.py | 14 ++++---------- clearml/binding/joblib_bind.py | 9 +++++++++ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py index 9b8f8255..3cacb47a 100644 --- a/clearml/binding/frameworks/pytorch_bind.py +++ b/clearml/binding/frameworks/pytorch_bind.py @@ -18,7 +18,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO): __patched = None __patched_lightning = None __patched_mmcv = None - __default_checkpoint_filename_counter = {} @staticmethod def update_current_task(task, **_): @@ -185,9 +184,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO): filename = f.name else: - filename = PatchPyTorchModelIO.__create_default_filename() + filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename() except Exception: - filename = PatchPyTorchModelIO.__create_default_filename() + filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename() # give the model a descriptive name based on the file name # noinspection PyBroadException @@ -195,7 +194,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO): model_name = Path(filename).stem if filename is not None else None except Exception: model_name = None - WeightsFileHandler.create_output_model( obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name) @@ -284,11 +282,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO): return model @staticmethod - def __create_default_filename(): + def __get_cached_checkpoint_filename(): tid = threading.current_thread().ident checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid) - if checkpoint_filename: - return checkpoint_filename - counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0) - PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1 - return "default_{}_{}".format(tid, counter) + return checkpoint_filename or None \ No newline at end of file diff --git a/clearml/binding/joblib_bind.py b/clearml/binding/joblib_bind.py index 547da002..12cde17c 100644 --- a/clearml/binding/joblib_bind.py +++ b/clearml/binding/joblib_bind.py @@ -48,6 +48,10 @@ class PatchedJoblib(object): joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call( joblib.numpy_pickle.NumpyPickler.__init__, PatchedJoblib._numpypickler) + joblib.memory.MemorizedFunc._cached_call = _patched_call( + joblib.memory.MemorizedFunc._cached_call, + PatchedJoblib._cached_call_recursion_guard + ) if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules: PatchedJoblib._patched_sk_joblib = True @@ -194,3 +198,8 @@ class PatchedJoblib(object): "Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework)) finally: return framework + + @staticmethod + def _cached_call_recursion_guard(original_fn, *args, **kwargs): + # used just to avoid getting into the `_load` binding in the context of memory caching + return original_fn(*args, **kwargs) \ No newline at end of file