mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Add ModelInfo.weights_object()
for store callback access to the actual model object being stored (valid for both pre/post save calls, otherwise None
)
This commit is contained in:
parent
1729b9019f
commit
a841ab7450
@ -91,6 +91,9 @@ class WeightsFileHandler(object):
|
|||||||
self.local_model_id = local_model_id
|
self.local_model_id = local_model_id
|
||||||
self.framework = framework
|
self.framework = framework
|
||||||
self.task = task
|
self.task = task
|
||||||
|
# temporary store reference to the actual model/weights object that was saved.
|
||||||
|
# only valid for store callbacks
|
||||||
|
self.weights_object = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_callback(func, target):
|
def _add_callback(func, target):
|
||||||
@ -369,6 +372,8 @@ class WeightsFileHandler(object):
|
|||||||
else:
|
else:
|
||||||
target_filename = Path(files[0]).name
|
target_filename = Path(files[0]).name
|
||||||
|
|
||||||
|
# pass model object to ModelInfo object, maybe someone can use it
|
||||||
|
model_info.weights_object = model
|
||||||
# call pre model callback functions
|
# call pre model callback functions
|
||||||
model_info.upload_filename = target_filename
|
model_info.upload_filename = target_filename
|
||||||
for cb in list(WeightsFileHandler._model_pre_callbacks.values()):
|
for cb in list(WeightsFileHandler._model_pre_callbacks.values()):
|
||||||
@ -377,6 +382,8 @@ class WeightsFileHandler(object):
|
|||||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# making sure we do not store an additional reference to the original model
|
||||||
|
model_info.weights_object = None
|
||||||
|
|
||||||
# if callbacks force us to leave they return None
|
# if callbacks force us to leave they return None
|
||||||
if model_info is None:
|
if model_info is None:
|
||||||
@ -430,6 +437,8 @@ class WeightsFileHandler(object):
|
|||||||
# WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
# WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||||
|
|
||||||
model_info.model = trains_out_model
|
model_info.model = trains_out_model
|
||||||
|
# pass model object to ModelInfo object, maybe someone can use it
|
||||||
|
model_info.weights_object = model
|
||||||
# call post model callback functions
|
# call post model callback functions
|
||||||
for cb in list(WeightsFileHandler._model_post_callbacks.values()):
|
for cb in list(WeightsFileHandler._model_post_callbacks.values()):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -437,6 +446,9 @@ class WeightsFileHandler(object):
|
|||||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# making sure we do not store an additional reference to the original model
|
||||||
|
model_info.weights_object = None
|
||||||
|
|
||||||
trains_out_model = model_info.model
|
trains_out_model = model_info.model
|
||||||
target_filename = model_info.upload_filename
|
target_filename = model_info.upload_filename
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user