Add support for Task hyper-parameter sections and meta-data

Add new Task configuration section
This commit is contained in:
allegroai 2020-08-10 08:45:25 +03:00
parent 42ba696518
commit 8c7e230898
14 changed files with 1076 additions and 107 deletions

View File

@ -4,5 +4,4 @@ from .utils import (
update_project_time,
validate_status_change,
split_by,
ParameterKeyEscaper,
)

View File

@ -0,0 +1,229 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
from boltons import iterutils
from apierrors import errors
from apimodels.tasks import (
HyperParamKey,
HyperParamItem,
ReplaceHyperparams,
Configuration,
)
from bll.task import TaskBLL
from config import config
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
from utilities.parameter_key_escaper import ParameterKeyEscaper
log = config.logger(__file__)
task_bll = TaskBLL()
class HyperParams:
_properties_section = "properties"
@classmethod
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
for task in tasks
}
@classmethod
def _get_params_list(
cls, items: Dict[str, Dict[str, ParamsItem]]
) -> Sequence[dict]:
ret = list(chain.from_iterable(v.values() for v in items.values()))
return [
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
]
@classmethod
def _normalize_params(cls, params: Sequence) -> bool:
"""
Lower case properties section and return True if it is the only section
"""
for p in params:
if p.section.lower() == cls._properties_section:
p.section = cls._properties_section
return all(p.section == cls._properties_section for p in params)
@classmethod
def delete_params(
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return task.update(**delete_cmds, last_update=datetime.utcnow())
@classmethod
def edit_params(
cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
sections = iterutils.bucketize(items, key=attrgetter("section"))
return {
ParameterKeyEscaper.escape(section): {
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
for param in params
}
for section, params in sections.items()
}
@classmethod
def get_configurations(
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
) -> Dict[str, dict]:
only = ["id"]
if names:
only.extend(
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
)
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {
"configuration": [
c.to_proper_dict()
for c in sorted(task.configuration.values(), key=attrgetter("name"))
]
}
for task in tasks
}
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, list]:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration=Sequence[str]
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
@staticmethod
def _get_task_for_update(
company: str, id: str, allow_all_statuses: bool = False
) -> Task:
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=id)
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)
return task

View File

@ -0,0 +1,201 @@
import itertools
from typing import Sequence, Tuple
import dpath
from apierrors import errors
from database.model.task.task import Task
from tools import safe_get
from utilities.parameter_key_escaper import ParameterKeyEscaper
hyperparams_default_section = "Args"
hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
if default_section is None:
return None, full_name
section, _, name = full_name.partition("/")
if section != tf_define_section:
return default_section, full_name
if not name:
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
return section, name
def _get_full_param_name(param: dict) -> str:
section = param.get("section")
if section != tf_define_section:
return param["name"]
return "/".join((section, param["name"]))
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
removed = 0
if not data:
return removed
if with_sections:
for section, section_data in list(data.items()):
removed += _remove_legacy_params(section_data)
if not section_data:
"""If section is empty after removing legacy params then delete it"""
del data[section]
else:
for key, param in list(data.items()):
if param.get("type") == hyperparams_legacy_type:
removed += 1
del data[key]
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
if not data:
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
return [
param for param in data.values() if param.get("type") == hyperparams_legacy_type
]
def params_prepare_for_save(fields: dict, previous_task: Task = None):
"""
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
and previous_task
and previous_task[new_params_field]
):
previous_data = previous_task.to_proper_dict().get(new_params_field)
removed = _remove_legacy_params(
previous_data, with_sections=default_section is not None
)
if not legacy_params and not removed:
# if we only need to delete legacy fields from the db
# but they are not there then there is no point to proceed
continue
fields_update = {new_params_field: previous_data}
params_unprepare_from_saved(fields_update)
fields.update(fields_update)
for full_name, value in legacy_params.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_params_field, section, name)))
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
def params_unprepare_from_saved(fields, copy_to_legacy=False):
"""
Unescape all section and param names for hyper params and configuration
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
)
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]
for prefix in (
"hyperparams.",
"-hyperparams.",
"configuration.",
"-configuration.",
):
paths = [
_process_path(path) if path.startswith(prefix) else path for path in paths
]
return paths

View File

@ -5,6 +5,7 @@ from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import dpath
import pymongo.results
import six
from mongoengine import Q
@ -34,7 +35,9 @@ from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
from utilities.parameter_key_escaper import ParameterKeyEscaper
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
org_bll = OrgBLL()
@ -82,11 +85,7 @@ class TaskBLL(object):
@staticmethod
def get_by_id(
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@ -126,18 +125,14 @@ class TaskBLL(object):
allow_public=allow_public,
return_dicts=False,
)
res = None
if only:
res = q.only(*only)
elif return_tasks:
res = list(q)
q = q.only(*only)
count = len(res) if res is not None else q.count()
if count != len(ids):
if q.count() != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)
if return_tasks:
return res
return list(q)
@staticmethod
def create(call: APICall, fields: dict):
@ -179,20 +174,31 @@ class TaskBLL(object):
project: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
params_dict["execution"] = {}
for legacy_param in ("parameters", "configuration"):
legacy_value = execution_overrides.pop(legacy_param, None)
if legacy_value is not None:
params_dict["execution"] = legacy_value
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
params_prepare_for_save(params_dict, previous_task=task)
artifacts = execution_dict.get("artifacts")
if artifacts:
@ -220,6 +226,8 @@ class TaskBLL(object):
if task.output
else None,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
@ -625,28 +633,34 @@ class TaskBLL(object):
return [a.key for a in added], [a.key for a in updated]
@staticmethod
def get_aggregated_project_execution_parameters(
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[str]]:
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": company_id,
"execution.parameters": {"$exists": True, "$gt": {}},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
{"$unwind": "$parameters"},
{"$group": {"_id": "$parameters.k"}},
{"$sort": {"_id": 1}},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{
"$group": {
"_id": 1,
@ -672,7 +686,12 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r["_id"])
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))

View File

@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from boltons.dictutils import OneToOne
from apierrors import errors
from database.errors import translate_errors_context
@ -172,26 +171,3 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
class ParameterKeyEscaper:
_mapping = OneToOne({".": "%2E", "$": "%24"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
return "%".join(map(cls._unescape, value.split("%%")))

View File

@ -49,13 +49,13 @@ class TaskSystemTags(object):
development = "development"
class Script(EmbeddedDocument):
class Script(EmbeddedDocument, ProperDictMixin):
binary = StringField(default="python")
repository = StringField(required=True)
repository = StringField(default="")
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
entry_point = StringField(default="")
working_dir = StringField()
requirements = SafeDictField()
diff = StringField()
@ -84,6 +84,21 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class ParamsItem(EmbeddedDocument, ProperDictMixin):
section = StringField(required=True)
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
@ -116,9 +131,12 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True},
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
@ -187,7 +205,7 @@ class Task(AttributedDocument):
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script)
script: Script = EmbeddedDocumentField(Script, default=Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
@ -196,3 +214,6 @@ class Task(AttributedDocument):
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@ -0,0 +1,36 @@
from pymongo.database import Database, Collection
from bll.task.param_utils import (
hyperparams_legacy_type,
hyperparams_default_section,
split_param_name,
)
from tools import safe_get
def migrate_backend(db: Database):
hyperparam_fields = ("execution.parameters", "hyperparams")
configuration_fields = ("execution.model_desc", "configuration")
collection: Collection = db["task"]
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
set_commands = {}
for (old_field, new_field), default_section in zip(
(hyperparam_fields, configuration_fields),
(hyperparams_default_section, None),
):
legacy = safe_get(doc, old_field, separator=".")
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_field, section, name)))
# if safe_get(doc, new_path) is not None:
# continue
new_value = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_value["section"] = section
set_commands[".".join(new_path)] = new_value
if set_commands:
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})

View File

@ -532,8 +532,8 @@ get_unique_metric_variants {
}
}
get_hyper_parameters {
"2.2" {
description: """Get a list of all hyper parameter names used in tasks within the given project."""
"2.9" {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
properties {
@ -557,9 +557,9 @@ get_hyper_parameters {
type: object
properties {
parameters {
description: "A list of hyper parameter names"
description: "A list of parameter sections and names"
type: array
items {type: string}
items {type: object}
}
remaining {
description: "Remaining results"

View File

@ -297,7 +297,80 @@ _definitions {
"$ref": "#/definitions/last_metrics_event"
}
}
params_item {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. The combination of section and name should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
configuration_item {
type: object
properties {
name {
description: "Name of the parameter. Should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
param_key {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
type: string
}
}
}
section_params {
description: "Task section params"
type: object
additionalProperties {
"$ref": "#/definitions/params_item"
}
}
replace_hyperparams_enum {
type: string
enum: [
none,
section,
all
]
}
task {
type: object
properties {
@ -418,9 +491,24 @@ _definitions {
"$ref": "#/definitions/last_metrics_variants"
}
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
}
get_by_id {
"2.1" {
description: "Gets task information"
@ -625,6 +713,20 @@ clone {
description: "The project of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_hyperparams {
description: "The hyper params for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
new_task_configuration {
description: "The configuration for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
execution_overrides {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
@ -698,6 +800,20 @@ create {
description: "Script info"
"$ref": "#/definitions/script"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
response {
@ -759,6 +875,20 @@ validate {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@ -909,6 +1039,20 @@ edit {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@ -1491,4 +1635,213 @@ make_private {
}
}
}
}
get_hyper_params {
"2.9": {
description: "Get the list of task hyper parameters"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
params {
type: object
description: "Hyper parameters (keyed by task ID)"
}
}
}
}
}
edit_hyper_params {
"2.9" {
description: "Add or update task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/params_item"}
}
replace_hyperparams {
description: """Can be set to one of the following:
'all' - all the hyper parameters will be replaced with the provided ones
'section' - the sections that present in the new parameters will be replaced with the provided parameters
'none' (the default value) - only the specific parameters will be updated or added"""
"$ref": "#/definitions/replace_hyperparams_enum"
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_hyper_params {
"2.9": {
description: "Delete task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
type: array
items { "$ref": "#/definitions/param_key" }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
get_configurations {
"2.9": {
description: "Get the list of task configurations"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
names {
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Configurations (keyed by task ID)"
}
}
}
}
}
get_configuration_names {
"2.9": {
description: "Get the list of task configuration items names"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Names of task configuration items (keyed by task ID)"
}
}
}
}
}
edit_configuration {
"2.9" {
description: "Add or update task configuration"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/configuration_item"}
}
replace_configuration {
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
type: boolean
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_configuration {
"2.9": {
description: "Delete task configuration items"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "List of configuration itemss to delete"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}

View File

@ -12,7 +12,6 @@ from apierrors.errors.bad_request import InvalidProjectId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.projects import (
GetHyperParamReq,
GetHyperParamResp,
ProjectReq,
ProjectTagsRequest,
)
@ -377,13 +376,12 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
@endpoint(
"projects.get_hyper_parameters",
min_version="2.2",
min_version="2.9",
request_data_model=GetHyperParamReq,
response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
page=request.page,

View File

@ -32,6 +32,13 @@ from apimodels.tasks import (
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
GetHyperParamsRequest,
EditHyperParamsRequest,
DeleteHyperParamsRequest,
GetConfigurationsRequest,
EditConfigurationRequest,
DeleteConfigurationRequest,
GetConfigurationNamesRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL, Tags
@ -41,9 +48,14 @@ from bll.task import (
ChangeStatusRequest,
update_project_time,
split_by,
ParameterKeyEscaper,
)
from bll.task.hyperparams import HyperParams
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from bll.task.param_utils import (
params_prepare_for_save,
params_unprepare_from_saved,
escape_paths,
)
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
@ -57,9 +69,9 @@ from database.model.task.task import (
)
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from service_repo.base import PartialVersion
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
from timing_context import TimingContext
from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
@ -120,30 +132,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
def escape_execution_parameters(call: APICall):
default_prefix = "execution.parameters."
def escape_paths(paths, prefix=default_prefix):
escaped_paths = []
for path in paths:
if path == prefix:
raise errors.bad_request.ValidationError(
"invalid task field", path=path
)
escaped_paths.append(
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
)
return escaped_paths
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
ordering = Task.get_ordering(call.data)
if ordering:
ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
Task.set_ordering(call.data, escape_paths(ordering))
@endpoint("tasks.get_all_ex", required_fields=[])
@ -275,12 +270,15 @@ create_fields = {
"input": None,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"script": None,
}
def prepare_for_save(call: APICall, fields: dict):
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
conform_tag_fields(call, fields, validate=True)
params_prepare_for_save(fields, previous_task=previous_task)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@ -293,12 +291,6 @@ def prepare_for_save(call: APICall, fields: dict):
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
@ -308,18 +300,15 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
conform_output_tags(call, tasks_data)
for task_data in tasks_data:
parameters = safe_get(task_data, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {
ParameterKeyEscaper.unescape(k): v for k, v in parameters.items()
}
dpath.set(task_data, "execution/parameters", parameters)
for data in tasks_data:
params_unprepare_from_saved(
fields=data,
copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@ -337,7 +326,7 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
return prepare_for_save(call, fields)
return prepare_for_save(call, fields, previous_task=previous_task)
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
@ -401,6 +390,8 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
project=request.new_task_project,
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
hyperparams=request.new_hyperparams,
configuration=request.new_configuration,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
)
@ -598,6 +589,100 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(updated=0)
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = {
"params": [{"task": task, **data} for task, data in tasks_params.items()]
}
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
)
}
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id, task_id=request.task, hyperparams=request.hyperparams
)
}
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
call.result.data = {
"configurations": [
{"task": task, **data} for task, data in tasks_params.items()
]
}
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks
)
call.result.data = {
"configurations": [
{"task": task, **data} for task, data in tasks_params.items()
]
}
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
)
}
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest
):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id, task_id=request.task, configuration=request.configuration
)
}
@endpoint(
"tasks.enqueue",
request_data_model=EnqueueRequest,

View File

@ -5,7 +5,6 @@ log = config.logger(__file__)
class TestTasksDiff(TestService):
def setUp(self, version="2.0"):
super(TestTasksDiff, self).setUp(version=version)
@ -17,7 +16,14 @@ class TestTasksDiff(TestService):
def _compare_script(self, task_id, script):
task = self.api.tasks.get_by_id(task=task_id).task
if not script:
self.assertFalse(task.get("script", None))
self.assertTrue(
task.get(
"script",
dict(
binary="python", repository="", entry_point="", requirements={}
),
)
)
else:
for key, value in script.items():
self.assertEqual(task.script[key], value)

View File

@ -114,7 +114,7 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.status, "created")
self.assertEqual(new_task.script, script)
self.assertEqual(new_task.parent, task)
self.assertEqual(new_task.execution.parameters, execution["parameters"])
# self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.system_tags, [])

View File

@ -0,0 +1,46 @@
from boltons.dictutils import OneToOne
from apierrors import errors
class ParameterKeyEscaper:
"""
Makes the fields name ready for use with MongoDB and Mongoengine
. and $ are replaced with their codes
__ and leading _ are escaped
Since % is used as an escape character the % is also escaped
"""
_mapping = OneToOne({".": "%2E", "$": "%24", "__": "%_%_"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
if value is None:
raise errors.bad_request.ValidationError("Key cannot be empty")
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
if value.startswith("_"):
value = "%_" + value[1:]
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
value = "%".join(map(cls._unescape, value.split("%%")))
if value.startswith("%_"):
value = "_" + value[2:]
return value