mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix code hangs when running with joblib (#1009)
This commit is contained in:
parent
e80d1f1ff4
commit
7c09251686
@ -18,7 +18,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
__patched = None
|
__patched = None
|
||||||
__patched_lightning = None
|
__patched_lightning = None
|
||||||
__patched_mmcv = None
|
__patched_mmcv = None
|
||||||
__default_checkpoint_filename_counter = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
@ -185,9 +184,9 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
filename = f.name
|
filename = f.name
|
||||||
else:
|
else:
|
||||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = PatchPyTorchModelIO.__create_default_filename()
|
filename = PatchPyTorchModelIO.__get_cached_checkpoint_filename()
|
||||||
|
|
||||||
# give the model a descriptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -195,7 +194,6 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
model_name = Path(filename).stem if filename is not None else None
|
model_name = Path(filename).stem if filename is not None else None
|
||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
WeightsFileHandler.create_output_model(
|
WeightsFileHandler.create_output_model(
|
||||||
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
|
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
|
||||||
|
|
||||||
@ -284,11 +282,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __create_default_filename():
|
def __get_cached_checkpoint_filename():
|
||||||
tid = threading.current_thread().ident
|
tid = threading.current_thread().ident
|
||||||
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
|
checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
|
||||||
if checkpoint_filename:
|
return checkpoint_filename or None
|
||||||
return checkpoint_filename
|
|
||||||
counter = PatchPyTorchModelIO.__default_checkpoint_filename_counter.setdefault(tid, 0)
|
|
||||||
PatchPyTorchModelIO.__default_checkpoint_filename_counter[tid] += 1
|
|
||||||
return "default_{}_{}".format(tid, counter)
|
|
@ -48,6 +48,10 @@ class PatchedJoblib(object):
|
|||||||
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
|
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
|
||||||
joblib.numpy_pickle.NumpyPickler.__init__,
|
joblib.numpy_pickle.NumpyPickler.__init__,
|
||||||
PatchedJoblib._numpypickler)
|
PatchedJoblib._numpypickler)
|
||||||
|
joblib.memory.MemorizedFunc._cached_call = _patched_call(
|
||||||
|
joblib.memory.MemorizedFunc._cached_call,
|
||||||
|
PatchedJoblib._cached_call_recursion_guard
|
||||||
|
)
|
||||||
|
|
||||||
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
||||||
PatchedJoblib._patched_sk_joblib = True
|
PatchedJoblib._patched_sk_joblib = True
|
||||||
@ -194,3 +198,8 @@ class PatchedJoblib(object):
|
|||||||
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
|
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
|
||||||
finally:
|
finally:
|
||||||
return framework
|
return framework
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cached_call_recursion_guard(original_fn, *args, **kwargs):
|
||||||
|
# used just to avoid getting into the `_load` binding in the context of memory caching
|
||||||
|
return original_fn(*args, **kwargs)
|
Loading…
Reference in New Issue
Block a user