mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Add ModelInfo.weights_object()
for store callback access to the actual model object being stored (valid for both pre/post save calls, otherwise None
)
This commit is contained in:
parent
1729b9019f
commit
a841ab7450
@ -91,6 +91,9 @@ class WeightsFileHandler(object):
|
||||
self.local_model_id = local_model_id
|
||||
self.framework = framework
|
||||
self.task = task
|
||||
# temporary store reference to the actual model/weights object that was saved.
|
||||
# only valid for store callbacks
|
||||
self.weights_object = None
|
||||
|
||||
@staticmethod
|
||||
def _add_callback(func, target):
|
||||
@ -369,6 +372,8 @@ class WeightsFileHandler(object):
|
||||
else:
|
||||
target_filename = Path(files[0]).name
|
||||
|
||||
# pass model object to ModelInfo object, maybe someone can use it
|
||||
model_info.weights_object = model
|
||||
# call pre model callback functions
|
||||
model_info.upload_filename = target_filename
|
||||
for cb in list(WeightsFileHandler._model_pre_callbacks.values()):
|
||||
@ -377,6 +382,8 @@ class WeightsFileHandler(object):
|
||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
# making sure we do not store an additional reference to the original model
|
||||
model_info.weights_object = None
|
||||
|
||||
# if callbacks force us to leave they return None
|
||||
if model_info is None:
|
||||
@ -430,6 +437,8 @@ class WeightsFileHandler(object):
|
||||
# WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||
|
||||
model_info.model = trains_out_model
|
||||
# pass model object to ModelInfo object, maybe someone can use it
|
||||
model_info.weights_object = model
|
||||
# call post model callback functions
|
||||
for cb in list(WeightsFileHandler._model_post_callbacks.values()):
|
||||
# noinspection PyBroadException
|
||||
@ -437,6 +446,9 @@ class WeightsFileHandler(object):
|
||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
# making sure we do not store an additional reference to the original model
|
||||
model_info.weights_object = None
|
||||
|
||||
trains_out_model = model_info.model
|
||||
target_filename = model_info.upload_filename
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user