mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Fix broken Keras binding support
This commit is contained in:
parent
9dfe36db9e
commit
1a5ed1132a
@ -1004,16 +1004,16 @@ class PatchKerasModelIO(object):
|
||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||
PatchKerasModelIO._updated_config)
|
||||
if hasattr(Sequential.from_config, '__func__'):
|
||||
Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__,
|
||||
PatchKerasModelIO._from_config)
|
||||
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
|
||||
PatchKerasModelIO._from_config))
|
||||
else:
|
||||
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
||||
|
||||
if Network is not None:
|
||||
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
|
||||
if hasattr(Sequential.from_config, '__func__'):
|
||||
Network.from_config.__func__ = _patched_call(Network.from_config.__func__,
|
||||
PatchKerasModelIO._from_config)
|
||||
Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
|
||||
PatchKerasModelIO._from_config))
|
||||
else:
|
||||
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
|
||||
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
|
||||
import six
|
||||
|
||||
from . import get_cache_dir
|
||||
from . import get_cache_dir, running_remotely
|
||||
from .defs import SESSION_CACHE_FILE
|
||||
|
||||
|
||||
@ -34,6 +34,9 @@ class SessionCache(object):
|
||||
@classmethod
|
||||
def store_dict(cls, unique_cache_name, dict_object):
|
||||
# type: (str, dict) -> None
|
||||
# disable session cache when running in remote execution mode
|
||||
if running_remotely():
|
||||
return
|
||||
cache = cls._load_cache()
|
||||
cache[unique_cache_name] = dict_object
|
||||
cls._store_cache(cache)
|
||||
@ -41,5 +44,8 @@ class SessionCache(object):
|
||||
@classmethod
|
||||
def load_dict(cls, unique_cache_name):
|
||||
# type: (str) -> dict
|
||||
# disable session cache when running in remote execution mode
|
||||
if running_remotely():
|
||||
return {}
|
||||
cache = cls._load_cache()
|
||||
return cache.get(unique_cache_name, {}) if cache else {}
|
||||
|
@ -678,7 +678,8 @@ class OutputModel(BaseModel):
|
||||
elif self._floating_data is not None:
|
||||
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
|
||||
if _Model._unwrap_design(self._floating_data.design):
|
||||
task.set_model_config(config_text=self._floating_data.design)
|
||||
if not task.get_model_config_text():
|
||||
task.set_model_config(config_text=self._floating_data.design)
|
||||
else:
|
||||
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
|
||||
|
||||
@ -904,7 +905,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
|
||||
|
||||
if self._task:
|
||||
if self._task and not self._task.get_model_config_text():
|
||||
self._task.set_model_config(config_text=config_text)
|
||||
|
||||
if self.id:
|
||||
@ -965,8 +966,8 @@ class OutputModel(BaseModel):
|
||||
config_text = self._task.get_model_config_text()
|
||||
parent = self._task.output_model_id or self._task.input_model_id
|
||||
self._base_model.update(
|
||||
labels=labels,
|
||||
design=config_text,
|
||||
labels=self._floating_data.labels or labels,
|
||||
design=self._floating_data.design or config_text,
|
||||
task_id=self._task.id,
|
||||
project_id=self._task.project,
|
||||
parent_id=parent,
|
||||
|
Loading…
Reference in New Issue
Block a user