mirror of
https://github.com/clearml/clearml
synced 2025-03-14 23:58:25 +00:00
Fix broken Keras binding support
This commit is contained in:
parent
9dfe36db9e
commit
1a5ed1132a
@ -1004,16 +1004,16 @@ class PatchKerasModelIO(object):
|
|||||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||||
PatchKerasModelIO._updated_config)
|
PatchKerasModelIO._updated_config)
|
||||||
if hasattr(Sequential.from_config, '__func__'):
|
if hasattr(Sequential.from_config, '__func__'):
|
||||||
Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__,
|
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
|
||||||
PatchKerasModelIO._from_config)
|
PatchKerasModelIO._from_config))
|
||||||
else:
|
else:
|
||||||
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
||||||
|
|
||||||
if Network is not None:
|
if Network is not None:
|
||||||
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
|
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
|
||||||
if hasattr(Sequential.from_config, '__func__'):
|
if hasattr(Sequential.from_config, '__func__'):
|
||||||
Network.from_config.__func__ = _patched_call(Network.from_config.__func__,
|
Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
|
||||||
PatchKerasModelIO._from_config)
|
PatchKerasModelIO._from_config))
|
||||||
else:
|
else:
|
||||||
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
|
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
|
||||||
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from . import get_cache_dir
|
from . import get_cache_dir, running_remotely
|
||||||
from .defs import SESSION_CACHE_FILE
|
from .defs import SESSION_CACHE_FILE
|
||||||
|
|
||||||
|
|
||||||
@ -34,6 +34,9 @@ class SessionCache(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def store_dict(cls, unique_cache_name, dict_object):
|
def store_dict(cls, unique_cache_name, dict_object):
|
||||||
# type: (str, dict) -> None
|
# type: (str, dict) -> None
|
||||||
|
# disable session cache when running in remote execution mode
|
||||||
|
if running_remotely():
|
||||||
|
return
|
||||||
cache = cls._load_cache()
|
cache = cls._load_cache()
|
||||||
cache[unique_cache_name] = dict_object
|
cache[unique_cache_name] = dict_object
|
||||||
cls._store_cache(cache)
|
cls._store_cache(cache)
|
||||||
@ -41,5 +44,8 @@ class SessionCache(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load_dict(cls, unique_cache_name):
|
def load_dict(cls, unique_cache_name):
|
||||||
# type: (str) -> dict
|
# type: (str) -> dict
|
||||||
|
# disable session cache when running in remote execution mode
|
||||||
|
if running_remotely():
|
||||||
|
return {}
|
||||||
cache = cls._load_cache()
|
cache = cls._load_cache()
|
||||||
return cache.get(unique_cache_name, {}) if cache else {}
|
return cache.get(unique_cache_name, {}) if cache else {}
|
||||||
|
@ -678,6 +678,7 @@ class OutputModel(BaseModel):
|
|||||||
elif self._floating_data is not None:
|
elif self._floating_data is not None:
|
||||||
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
||||||
if _Model._unwrap_design(self._floating_data.design):
|
if _Model._unwrap_design(self._floating_data.design):
|
||||||
|
if not task.get_model_config_text():
|
||||||
task.set_model_config(config_text=self._floating_data.design)
|
task.set_model_config(config_text=self._floating_data.design)
|
||||||
else:
|
else:
|
||||||
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
|
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
|
||||||
@ -904,7 +905,7 @@ class OutputModel(BaseModel):
|
|||||||
|
|
||||||
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
|
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
|
||||||
|
|
||||||
if self._task:
|
if self._task and not self._task.get_model_config_text():
|
||||||
self._task.set_model_config(config_text=config_text)
|
self._task.set_model_config(config_text=config_text)
|
||||||
|
|
||||||
if self.id:
|
if self.id:
|
||||||
@ -965,8 +966,8 @@ class OutputModel(BaseModel):
|
|||||||
config_text = self._task.get_model_config_text()
|
config_text = self._task.get_model_config_text()
|
||||||
parent = self._task.output_model_id or self._task.input_model_id
|
parent = self._task.output_model_id or self._task.input_model_id
|
||||||
self._base_model.update(
|
self._base_model.update(
|
||||||
labels=labels,
|
labels=self._floating_data.labels or labels,
|
||||||
design=config_text,
|
design=self._floating_data.design or config_text,
|
||||||
task_id=self._task.id,
|
task_id=self._task.id,
|
||||||
project_id=self._task.project,
|
project_id=self._task.project,
|
||||||
parent_id=parent,
|
parent_id=parent,
|
||||||
|
Loading…
Reference in New Issue
Block a user