mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Fix filtering on hyperparams (https://github.com/allegroai/clearml/issues/385, https://clearml.slack.com/archives/CTK20V944/p1626600582284700)
This commit is contained in:
parent
09ab2af34c
commit
56aea1ffb8
@ -1,11 +1,10 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Tuple
|
from typing import Sequence, Tuple, Optional
|
||||||
|
|
||||||
import dpath
|
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.database.model.task.task import Task
|
from apiserver.database.model.task.task import Task
|
||||||
from apiserver.tools import safe_get
|
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
|
|||||||
tf_define_section = "TF_DEFINE"
|
tf_define_section = "TF_DEFINE"
|
||||||
|
|
||||||
|
|
||||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||||
"""
|
"""
|
||||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||||
"""
|
"""
|
||||||
@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
|||||||
return removed
|
return removed
|
||||||
|
|
||||||
|
|
||||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||||
"""
|
"""
|
||||||
Remove the legacy params from the data dict and return the number of removed params
|
Remove the legacy params from the data dict and return the number of removed params
|
||||||
If the path not found then return 0
|
If the path not found then return 0
|
||||||
@ -71,9 +70,11 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
if with_sections:
|
if with_sections:
|
||||||
return itertools.chain.from_iterable(
|
return list(
|
||||||
|
itertools.chain.from_iterable(
|
||||||
_get_legacy_params(section_data) for section_data in data.values()
|
_get_legacy_params(section_data) for section_data in data.values()
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||||
@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
|||||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||||
"""
|
"""
|
||||||
for old_params_field, new_params_field, default_section in (
|
for old_params_field, new_params_field, default_section in (
|
||||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||||
("execution/model_desc", "configuration", None),
|
(("execution", "model_desc"), "configuration", None),
|
||||||
):
|
):
|
||||||
legacy_params = safe_get(fields, old_params_field)
|
legacy_params = nested_get(fields, old_params_field)
|
||||||
if legacy_params is None:
|
if legacy_params is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not safe_get(fields, new_params_field)
|
not fields.get(new_params_field)
|
||||||
and previous_task
|
and previous_task
|
||||||
and previous_task[new_params_field]
|
and previous_task[new_params_field]
|
||||||
):
|
):
|
||||||
@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
|||||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||||
if section is not None:
|
if section is not None:
|
||||||
new_param["section"] = section
|
new_param["section"] = section
|
||||||
dpath.new(fields, new_path, new_param)
|
nested_set(fields, new_path, new_param)
|
||||||
dpath.delete(fields, old_params_field)
|
nested_delete(fields, old_params_field)
|
||||||
|
|
||||||
for param_field in ("hyperparams", "configuration"):
|
for param_field in ("hyperparams", "configuration"):
|
||||||
params = safe_get(fields, param_field)
|
params = fields.get(param_field)
|
||||||
if params:
|
if params:
|
||||||
escaped_params = {
|
escaped_params = {
|
||||||
ParameterKeyEscaper.escape(key): {
|
ParameterKeyEscaper.escape(key): {
|
||||||
@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
|||||||
else value
|
else value
|
||||||
for key, value in params.items()
|
for key, value in params.items()
|
||||||
}
|
}
|
||||||
dpath.set(fields, param_field, escaped_params)
|
fields[param_field] = escaped_params
|
||||||
|
|
||||||
|
|
||||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||||
@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
|||||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||||
"""
|
"""
|
||||||
for param_field in ("hyperparams", "configuration"):
|
for param_field in ("hyperparams", "configuration"):
|
||||||
params = safe_get(fields, param_field)
|
params = fields.get(param_field)
|
||||||
if params:
|
if params:
|
||||||
unescaped_params = {
|
unescaped_params = {
|
||||||
ParameterKeyEscaper.unescape(key): {
|
ParameterKeyEscaper.unescape(key): {
|
||||||
@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
|||||||
else value
|
else value
|
||||||
for key, value in params.items()
|
for key, value in params.items()
|
||||||
}
|
}
|
||||||
dpath.set(fields, param_field, unescaped_params)
|
fields[param_field] = unescaped_params
|
||||||
|
|
||||||
if copy_to_legacy:
|
if copy_to_legacy:
|
||||||
for new_params_field, old_params_field, use_sections in (
|
for new_params_field, old_params_field, use_sections in (
|
||||||
(f"hyperparams", "execution/parameters", True),
|
("hyperparams", ("execution", "parameters"), True),
|
||||||
(f"configuration", "execution/model_desc", False),
|
("configuration", ("execution", "model_desc"), False),
|
||||||
):
|
):
|
||||||
legacy_params = _get_legacy_params(
|
legacy_params = _get_legacy_params(
|
||||||
safe_get(fields, new_params_field), with_sections=use_sections
|
fields.get(new_params_field), with_sections=use_sections
|
||||||
)
|
)
|
||||||
if legacy_params:
|
if legacy_params:
|
||||||
dpath.new(
|
nested_set(
|
||||||
fields,
|
fields,
|
||||||
old_params_field,
|
old_params_field,
|
||||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||||
@ -174,7 +175,7 @@ def _process_path(path: str):
|
|||||||
Need to unescape and apply a full mongo escaping
|
Need to unescape and apply a full mongo escaping
|
||||||
"""
|
"""
|
||||||
parts = path.split(".")
|
parts = path.split(".")
|
||||||
if len(parts) < 2 or len(parts) > 3:
|
if len(parts) < 2 or len(parts) > 4:
|
||||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||||
return ".".join(
|
return ".".join(
|
||||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||||
@ -184,7 +185,7 @@ def _process_path(path: str):
|
|||||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||||
for old_prefix, new_prefix in (
|
for old_prefix, new_prefix in (
|
||||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||||
("execution.model_desc", f"configuration"),
|
("execution.model_desc", "configuration"),
|
||||||
("execution.docker_cmd", "container")
|
("execution.docker_cmd", "container")
|
||||||
):
|
):
|
||||||
path: str
|
path: str
|
||||||
|
@ -219,6 +219,7 @@ class Task(AttributedDocument):
|
|||||||
"status",
|
"status",
|
||||||
"project",
|
"project",
|
||||||
"parent",
|
"parent",
|
||||||
|
"hyperparams.*",
|
||||||
),
|
),
|
||||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||||
datetime_fields=("status_changed", "last_update"),
|
datetime_fields=("status_changed", "last_update"),
|
||||||
|
Loading…
Reference in New Issue
Block a user