mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Fix code hangs when running with joblib (#1009)
This commit is contained in:
parent
e80d1f1ff4
commit
7c09251686
@ -18,7 +18,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
__patched = None
|
||||
__patched_lightning = None
|
||||
__patched_mmcv = None
|
||||
__default_checkpoint_filename_counter = {}
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **_):
|
||||
@ -185,9 +184,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
|
||||
filename = f.name
|
||||
else:
|
||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
|
||||
except Exception:
|
||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
||||
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
@ -195,7 +194,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
model_name = Path(filename).stem if filename is not None else None
|
||||
except Exception:
|
||||
model_name = None
|
||||
|
||||
WeightsFileHandler.create_output_model(
|
||||
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
|
||||
|
||||
@ -284,11 +282,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def __create_default_filename():
|
||||
def __get_cached_checkpoint_filename():
|
||||
tid = threading.current_thread().ident
|
||||
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
|
||||
if checkpoint_filename:
|
||||
return checkpoint_filename
|
||||
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
|
||||
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
|
||||
return "default_{}_{}".format(tid, counter)
|
||||
return checkpoint_filename or None
|
@ -48,6 +48,10 @@ class PatchedJoblib(object):
|
||||
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
|
||||
joblib.numpy_pickle.NumpyPickler.__init__,
|
||||
PatchedJoblib._numpypickler)
|
||||
joblib.memory.MemorizedFunc._cached_call = _patched_call(
|
||||
joblib.memory.MemorizedFunc._cached_call,
|
||||
PatchedJoblib._cached_call_recursion_guard
|
||||
)
|
||||
|
||||
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
||||
PatchedJoblib._patched_sk_joblib = True
|
||||
@ -194,3 +198,8 @@ class PatchedJoblib(object):
|
||||
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
|
||||
finally:
|
||||
return framework
|
||||
|
||||
@staticmethod
|
||||
def _cached_call_recursion_guard(original_fn, *args, **kwargs):
|
||||
# used just to avoid getting into the `_load` binding in the context of memory caching
|
||||
return original_fn(*args, **kwargs)
|
Loading…
Reference in New Issue
Block a user