Fix code hang when running with joblib (#1009)

This commit is contained in:
allegroai 2023-05-21 09:42:32 +03:00
parent e80d1f1ff4
commit 7c09251686
2 changed files with 13 additions and 10 deletions

View File

@ -18,7 +18,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
__patched = None
__patched_lightning = None
__patched_mmcv = None
__default_checkpoint_filename_counter = {}
@staticmethod
def update_current_task(task, **_):
@ -185,9 +184,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
filename = f.name
else:
filename = PatchPyTorchModelIO.__create_default_filename()
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
except Exception:
filename = PatchPyTorchModelIO.__create_default_filename()
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
# give the model a descriptive name based on the file name
# noinspection PyBroadException
@ -195,7 +194,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
model_name = Path(filename).stem if filename is not None else None
except Exception:
model_name = None
WeightsFileHandler.create_output_model(
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
@ -284,11 +282,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return model
@staticmethod
def __create_default_filename():
def __get_cached_checkpoint_filename():
tid = threading.current_thread().ident
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
if checkpoint_filename:
return checkpoint_filename
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
return "default_{}_{}".format(tid, counter)
return checkpoint_filename or None

View File

@ -48,6 +48,10 @@ class PatchedJoblib(object):
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
joblib.numpy_pickle.NumpyPickler.__init__,
PatchedJoblib._numpypickler)
joblib.memory.MemorizedFunc._cached_call = _patched_call(
joblib.memory.MemorizedFunc._cached_call,
PatchedJoblib._cached_call_recursion_guard
)
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
PatchedJoblib._patched_sk_joblib = True
@ -194,3 +198,8 @@ class PatchedJoblib(object):
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
finally:
return framework
@staticmethod
def _cached_call_recursion_guard(original_fn, *args, **kwargs):
# used just to avoid getting into the `_load` binding in the context of memory caching
return original_fn(*args, **kwargs)