allegroai 2021-07-25 13:55:09 +03:00
parent 09ab2af34c
commit 56aea1ffb8
2 changed files with 25 additions and 23 deletions

View File

@ -1,11 +1,10 @@
import itertools import itertools
from typing import Sequence, Tuple from typing import Sequence, Tuple, Optional
import dpath
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE" tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]: def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
""" """
Return parameter section and name. The section is either TF_DEFINE or the default one Return parameter section and name. The section is either TF_DEFINE or the default one
""" """
@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
return removed return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]: def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
""" """
Remove the legacy params from the data dict and return the number of removed params Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0 If the path not found then return 0
@ -71,9 +70,11 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
return [] return []
if with_sections: if with_sections:
return itertools.chain.from_iterable( return list(
itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values() _get_legacy_params(section_data) for section_data in data.values()
) )
)
return [ return [
param for param in data.values() if param.get("type") == hyperparams_legacy_type param for param in data.values() if param.get("type") == hyperparams_legacy_type
@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
Escape all the section and param names for hyper params and configuration to make it mongo sage Escape all the section and param names for hyper params and configuration to make it mongo sage
""" """
for old_params_field, new_params_field, default_section in ( for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section), (("execution", "parameters"), "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None), (("execution", "model_desc"), "configuration", None),
): ):
legacy_params = safe_get(fields, old_params_field) legacy_params = nested_get(fields, old_params_field)
if legacy_params is None: if legacy_params is None:
continue continue
if ( if (
not safe_get(fields, new_params_field) not fields.get(new_params_field)
and previous_task and previous_task
and previous_task[new_params_field] and previous_task[new_params_field]
): ):
@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value)) new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None: if section is not None:
new_param["section"] = section new_param["section"] = section
dpath.new(fields, new_path, new_param) nested_set(fields, new_path, new_param)
dpath.delete(fields, old_params_field) nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"): for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field) params = fields.get(param_field)
if params: if params:
escaped_params = { escaped_params = {
ParameterKeyEscaper.escape(key): { ParameterKeyEscaper.escape(key): {
@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
else value else value
for key, value in params.items() for key, value in params.items()
} }
dpath.set(fields, param_field, escaped_params) fields[param_field] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False): def params_unprepare_from_saved(fields, copy_to_legacy=False):
@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
""" """
for param_field in ("hyperparams", "configuration"): for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field) params = fields.get(param_field)
if params: if params:
unescaped_params = { unescaped_params = {
ParameterKeyEscaper.unescape(key): { ParameterKeyEscaper.unescape(key): {
@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
else value else value
for key, value in params.items() for key, value in params.items()
} }
dpath.set(fields, param_field, unescaped_params) fields[param_field] = unescaped_params
if copy_to_legacy: if copy_to_legacy:
for new_params_field, old_params_field, use_sections in ( for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True), ("hyperparams", ("execution", "parameters"), True),
(f"configuration", "execution/model_desc", False), ("configuration", ("execution", "model_desc"), False),
): ):
legacy_params = _get_legacy_params( legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections fields.get(new_params_field), with_sections=use_sections
) )
if legacy_params: if legacy_params:
dpath.new( nested_set(
fields, fields,
old_params_field, old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params}, {_get_full_param_name(p): p["value"] for p in legacy_params},
@ -174,7 +175,7 @@ def _process_path(path: str):
Need to unescape and apply a full mongo escaping Need to unescape and apply a full mongo escaping
""" """
parts = path.split(".") parts = path.split(".")
if len(parts) < 2 or len(parts) > 3: if len(parts) < 2 or len(parts) > 4:
raise errors.bad_request.ValidationError("invalid task field", path=path) raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join( return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
@ -184,7 +185,7 @@ def _process_path(path: str):
def escape_paths(paths: Sequence[str]) -> Sequence[str]: def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in ( for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"), ("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"), ("execution.model_desc", "configuration"),
("execution.docker_cmd", "container") ("execution.docker_cmd", "container")
): ):
path: str path: str

View File

@ -219,6 +219,7 @@ class Task(AttributedDocument):
"status", "status",
"project", "project",
"parent", "parent",
"hyperparams.*",
), ),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"), range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"), datetime_fields=("status_changed", "last_update"),