mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Fix broken Keras binding support
This commit is contained in:
		
							parent
							
								
									9dfe36db9e
								
							
						
					
					
						commit
						1a5ed1132a
					
				| @ -1004,16 +1004,16 @@ class PatchKerasModelIO(object): | ||||
|                 Sequential._updated_config = _patched_call(Sequential._updated_config, | ||||
|                                                            PatchKerasModelIO._updated_config) | ||||
|                 if hasattr(Sequential.from_config, '__func__'): | ||||
|                     Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__, | ||||
|                                                                     PatchKerasModelIO._from_config) | ||||
|                     Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__, | ||||
|                                                          PatchKerasModelIO._from_config)) | ||||
|                 else: | ||||
|                     Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config) | ||||
| 
 | ||||
|             if Network is not None: | ||||
|                 Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config) | ||||
|                 if hasattr(Sequential.from_config, '__func__'): | ||||
|                     Network.from_config.__func__ = _patched_call(Network.from_config.__func__, | ||||
|                                                                  PatchKerasModelIO._from_config) | ||||
|                     Network.from_config = classmethod(_patched_call(Network.from_config.__func__, | ||||
|                                                                     PatchKerasModelIO._from_config)) | ||||
|                 else: | ||||
|                     Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config) | ||||
|                 Network.save = _patched_call(Network.save, PatchKerasModelIO._save) | ||||
|  | ||||
| @ -2,7 +2,7 @@ import json | ||||
| 
 | ||||
| import six | ||||
| 
 | ||||
| from . import get_cache_dir | ||||
| from . import get_cache_dir, running_remotely | ||||
| from .defs import SESSION_CACHE_FILE | ||||
| 
 | ||||
| 
 | ||||
| @ -34,6 +34,9 @@ class SessionCache(object): | ||||
|     @classmethod | ||||
|     def store_dict(cls, unique_cache_name, dict_object): | ||||
|         # type: (str, dict) -> None | ||||
|         # disable session cache when running in remote execution mode | ||||
|         if running_remotely(): | ||||
|             return | ||||
|         cache = cls._load_cache() | ||||
|         cache[unique_cache_name] = dict_object | ||||
|         cls._store_cache(cache) | ||||
| @ -41,5 +44,8 @@ class SessionCache(object): | ||||
|     @classmethod | ||||
|     def load_dict(cls, unique_cache_name): | ||||
|         # type: (str) -> dict | ||||
|         # disable session cache when running in remote execution mode | ||||
|         if running_remotely(): | ||||
|             return {} | ||||
|         cache = cls._load_cache() | ||||
|         return cache.get(unique_cache_name, {}) if cache else {} | ||||
|  | ||||
| @ -678,7 +678,8 @@ class OutputModel(BaseModel): | ||||
|         elif self._floating_data is not None: | ||||
|             # we copy configuration / labels if they exist, obviously someone wants them as the output base model | ||||
|             if _Model._unwrap_design(self._floating_data.design): | ||||
|                 task.set_model_config(config_text=self._floating_data.design) | ||||
|                 if not task.get_model_config_text(): | ||||
|                     task.set_model_config(config_text=self._floating_data.design) | ||||
|             else: | ||||
|                 self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text()) | ||||
| 
 | ||||
| @ -904,7 +905,7 @@ class OutputModel(BaseModel): | ||||
| 
 | ||||
|         config_text = self._resolve_config(config_text=config_text, config_dict=config_dict) | ||||
| 
 | ||||
|         if self._task: | ||||
|         if self._task and not self._task.get_model_config_text(): | ||||
|             self._task.set_model_config(config_text=config_text) | ||||
| 
 | ||||
|         if self.id: | ||||
| @ -965,8 +966,8 @@ class OutputModel(BaseModel): | ||||
|         config_text = self._task.get_model_config_text() | ||||
|         parent = self._task.output_model_id or self._task.input_model_id | ||||
|         self._base_model.update( | ||||
|             labels=labels, | ||||
|             design=config_text, | ||||
|             labels=self._floating_data.labels or labels, | ||||
|             design=self._floating_data.design or config_text, | ||||
|             task_id=self._task.id, | ||||
|             project_id=self._task.project, | ||||
|             parent_id=parent, | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai