mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add ModelInfo.weights_object() for store callback access to the actual model object being stored (valid for both pre/post save calls, otherwise None)
				
					
				
			This commit is contained in:
		
							parent
							
								
									1729b9019f
								
							
						
					
					
						commit
						a841ab7450
					
				| @ -91,6 +91,9 @@ class WeightsFileHandler(object): | ||||
|             self.local_model_id = local_model_id | ||||
|             self.framework = framework | ||||
|             self.task = task | ||||
|             # temporary store reference to the actual model/weights object that was saved. | ||||
|             # only valid for store callbacks | ||||
|             self.weights_object = None | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _add_callback(func, target): | ||||
| @ -369,6 +372,8 @@ class WeightsFileHandler(object): | ||||
|             else: | ||||
|                 target_filename = Path(files[0]).name | ||||
| 
 | ||||
|             # pass model object to ModelInfo object, maybe someone can use it | ||||
|             model_info.weights_object = model | ||||
|             # call pre model callback functions | ||||
|             model_info.upload_filename = target_filename | ||||
|             for cb in list(WeightsFileHandler._model_pre_callbacks.values()): | ||||
| @ -377,6 +382,8 @@ class WeightsFileHandler(object): | ||||
|                     model_info = cb(WeightsFileHandler.CallbackType.save, model_info) | ||||
|                 except Exception: | ||||
|                     pass | ||||
|             # making sure we do not store an additional reference to the original model | ||||
|             model_info.weights_object = None | ||||
| 
 | ||||
|             # if callbacks force us to leave they return None | ||||
|             if model_info is None: | ||||
| @ -430,6 +437,8 @@ class WeightsFileHandler(object): | ||||
|                 #     WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) | ||||
| 
 | ||||
|             model_info.model = trains_out_model | ||||
|             # pass model object to ModelInfo object, maybe someone can use it | ||||
|             model_info.weights_object = model | ||||
|             # call post model callback functions | ||||
|             for cb in list(WeightsFileHandler._model_post_callbacks.values()): | ||||
|                 # noinspection PyBroadException | ||||
| @ -437,6 +446,9 @@ class WeightsFileHandler(object): | ||||
|                     model_info = cb(WeightsFileHandler.CallbackType.save, model_info) | ||||
|                 except Exception: | ||||
|                     pass | ||||
|             # making sure we do not store an additional reference to the original model | ||||
|             model_info.weights_object = None | ||||
| 
 | ||||
|             trains_out_model = model_info.model | ||||
|             target_filename = model_info.upload_filename | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai