Fix broken Keras binding support

This commit is contained in:
allegroai 2019-09-09 21:49:55 +03:00
parent 9dfe36db9e
commit 1a5ed1132a
3 changed files with 16 additions and 9 deletions

View File

@ -1004,16 +1004,16 @@ class PatchKerasModelIO(object):
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config)
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
if hasattr(Sequential.from_config, '__func__'):
Network.from_config.__func__ = _patched_call(Network.from_config.__func__,
PatchKerasModelIO._from_config)
Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
PatchKerasModelIO._from_config))
else:
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)

View File

@ -2,7 +2,7 @@ import json
import six
from . import get_cache_dir
from . import get_cache_dir, running_remotely
from .defs import SESSION_CACHE_FILE
@ -34,6 +34,9 @@ class SessionCache(object):
@classmethod
def store_dict(cls, unique_cache_name, dict_object):
# type: (str, dict) -> None
# disable session cache when running in remote execution mode
if running_remotely():
return
cache = cls._load_cache()
cache[unique_cache_name] = dict_object
cls._store_cache(cache)
@ -41,5 +44,8 @@ class SessionCache(object):
@classmethod
def load_dict(cls, unique_cache_name):
# type: (str) -> dict
# disable session cache when running in remote execution mode
if running_remotely():
return {}
cache = cls._load_cache()
return cache.get(unique_cache_name, {}) if cache else {}

View File

@ -678,6 +678,7 @@ class OutputModel(BaseModel):
elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
if _Model._unwrap_design(self._floating_data.design):
if not task.get_model_config_text():
task.set_model_config(config_text=self._floating_data.design)
else:
self._floating_data.design = _Model._wrap_design(self._task.get_model_config_text())
@ -904,7 +905,7 @@ class OutputModel(BaseModel):
config_text = self._resolve_config(config_text=config_text, config_dict=config_dict)
if self._task:
if self._task and not self._task.get_model_config_text():
self._task.set_model_config(config_text=config_text)
if self.id:
@ -965,8 +966,8 @@ class OutputModel(BaseModel):
config_text = self._task.get_model_config_text()
parent = self._task.output_model_id or self._task.input_model_id
self._base_model.update(
labels=labels,
design=config_text,
labels=self._floating_data.labels or labels,
design=self._floating_data.design or config_text,
task_id=self._task.id,
project_id=self._task.project,
parent_id=parent,