Add ModelInfo.weights_object() for store callback access to the actual model object being stored (valid for both pre/post save calls, otherwise None)

This commit is contained in:
allegroai 2022-11-21 16:31:58 +02:00
parent 1729b9019f
commit a841ab7450

View File

@ -91,6 +91,9 @@ class WeightsFileHandler(object):
self.local_model_id = local_model_id
self.framework = framework
self.task = task
# temporary store reference to the actual model/weights object that was saved.
# only valid for store callbacks
self.weights_object = None
@staticmethod
def _add_callback(func, target):
@ -369,6 +372,8 @@ class WeightsFileHandler(object):
else:
target_filename = Path(files[0]).name
# pass model object to ModelInfo object, maybe someone can use it
model_info.weights_object = model
# call pre model callback functions
model_info.upload_filename = target_filename
for cb in list(WeightsFileHandler._model_pre_callbacks.values()):
@ -377,6 +382,8 @@ class WeightsFileHandler(object):
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception:
pass
# making sure we do not store an additional reference to the original model
model_info.weights_object = None
# if callbacks force us to leave they return None
if model_info is None:
@ -430,6 +437,8 @@ class WeightsFileHandler(object):
# WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
model_info.model = trains_out_model
# pass model object to ModelInfo object, maybe someone can use it
model_info.weights_object = model
# call post model callback functions
for cb in list(WeightsFileHandler._model_post_callbacks.values()):
# noinspection PyBroadException
@ -437,6 +446,9 @@ class WeightsFileHandler(object):
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception:
pass
# making sure we do not store an additional reference to the original model
model_info.weights_object = None
trains_out_model = model_info.model
target_filename = model_info.upload_filename