Allow using "$", "." and whitespaces in hyper-parameter keys

allegroai 2020-01-02 15:28:50 +02:00
parent 7d10bbdf8e
commit dedac3b2fe
7 changed files with 167 additions and 82 deletions

@@ -5,12 +5,12 @@ from typing import Union, Type, Iterable
import jsonmodels.errors
import six
import validators
from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors
@@ -66,9 +66,7 @@ class DictField(fields.BaseField):
value_types = tuple()
return tuple(
_LazyType(type_)
if isinstance(type_, six.string_types)
else type_
_LazyType(type_) if isinstance(type_, six.string_types) else type_
for type_ in value_types
)
@@ -107,7 +105,7 @@ class IntField(fields.IntField):
def validate_lucene_query(value):
if value == '':
if value == "":
return
try:
parser.parse(value)
@@ -125,6 +123,7 @@ class LuceneQueryField(fields.StringField):
class NullableEnumValidator(EnumValidator):
"""Validator for enums that allows a None value."""
def validate(self, value):
if value is not None:
super(NullableEnumValidator, self).validate(value)
@@ -153,10 +152,6 @@ class EnumField(fields.StringField):
class ActualEnumField(fields.StringField):
@property
def types(self):
return self.__enum,
def __init__(
self,
enum_class: Type[Enum],
@@ -167,6 +162,7 @@ class ActualEnumField(fields.StringField):
**kwargs
):
self.__enum = enum_class
self.types = (enum_class,)
# noinspection PyTypeChecker
choices = list(enum_class)
validator_cls = EnumValidator if required else NullableEnumValidator
@@ -197,7 +193,7 @@ class EmailField(fields.StringField):
super().validate(value)
if value is None:
return
if validators.email(value) is not True:
if email_validator(value) is not True:
raise errors.bad_request.InvalidEmailAddress()
@@ -206,7 +202,7 @@ class DomainField(fields.StringField):
super().validate(value)
if value is None:
return
if validators.domain(value) is not True:
if domain_validator(value) is not True:
raise errors.bad_request.InvalidDomainName()

@@ -4,4 +4,5 @@ from .utils import (
update_project_time,
validate_status_change,
split_by,
ParameterKeyEscaper,
)

@@ -1,4 +1,3 @@
import re
from collections import OrderedDict
from datetime import datetime, timedelta
from operator import attrgetter
@@ -32,8 +31,7 @@ from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__)
@@ -171,6 +169,11 @@ class TaskBLL(object):
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
artifacts = execution_dict.get("artifacts")
if artifacts:
@@ -178,25 +181,28 @@ class TaskBLL(object):
a for a in artifacts if a.get("mode") != ArtifactModes.output
]
now = datetime.utcnow()
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
execution=execution_dict,
)
cls.validate(new_task)
new_task.save()
with translate_errors_context():
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
execution=execution_dict,
)
cls.validate(new_task)
new_task.save()
return new_task
@classmethod
@@ -215,18 +221,6 @@ class TaskBLL(object):
cls.validate_execution_model(task)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
if invalid_keys:
raise errors.bad_request.ValidationError(
"execution.parameters keys contain whitespace", keys=invalid_keys
)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [
@@ -658,7 +652,10 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [r["_id"] for r in result.get("results", [])]
results = [
ParameterKeyEscaper.unescape(r["_id"])
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

@@ -3,6 +3,7 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from boltons.dictutils import OneToOne
from apierrors import errors
from database.errors import translate_errors_context
@@ -171,3 +172,26 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
class ParameterKeyEscaper:
_mapping = OneToOne({".": "%2E", "$": "%24"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
return "%".join(map(cls._unescape, value.split("%%")))

@@ -1,7 +1,7 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union
from typing import Collection, Sequence, Union, Optional
from boltons.iterutils import first
from dateutil.parser import parse as parse_datetime
@@ -60,7 +60,7 @@ class ProperDictMixin(object):
class GetMixin(PropsMixin):
_text_score = "$text_score"
_projection_key = "projection"
_ordering_key = "order_by"
_search_text_key = "search_text"
@@ -270,11 +270,26 @@ class GetMixin(PropsMixin):
return override_projection
if not parameters:
return []
return parameters.get("projection") or parameters.get("only_fields", [])
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def set_default_ordering(cls, parameters, value):
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
parameters[cls._projection_key] = value
return value
@classmethod
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
return parameters.get(cls._ordering_key)
@classmethod
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters[cls._ordering_key] = value
return value
@classmethod
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def get_many_with_join(

@@ -1,12 +1,11 @@
from copy import deepcopy
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar
from typing import Sequence, Callable, Type, TypeVar, Union
import attr
import dpath
import mongoengine
import six
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@@ -33,7 +32,13 @@ from apimodels.tasks import (
)
from bll.event import EventBLL
from bll.queue import QueueBLL
from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by
from bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
split_by,
ParameterKeyEscaper,
)
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
@@ -97,13 +102,37 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True
)
task_dict = task.to_proper_dict()
conform_output_tags(call, task_dict)
unprepare_from_saved(call, task_dict)
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall):
default_prefix = "execution.parameters."
def escape_paths(paths, prefix=default_prefix):
return [
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
for path in paths
]
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
ordering = Task.get_ordering(call.data)
if ordering:
# escape both the ascending and the descending ("-"-prefixed) order_by forms
ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
@@ -112,13 +141,16 @@ def get_all_ex(call: APICall):
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
conform_output_tags(call, tasks)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
@@ -128,7 +160,7 @@ def get_all(call: APICall):
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
conform_output_tags(call, tasks)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@@ -223,6 +255,45 @@ create_fields = {
}
def prepare_for_save(call: APICall, fields: dict):
conform_tag_fields(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
try:
path = f"script/{field}"
value = dpath.get(fields, path)
if isinstance(value, str):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
conform_output_tags(call, tasks_data)
for task_data in tasks_data:
parameters = safe_get(task_data, "execution/parameters")
if parameters is not None:
# Unescape keys that were stored in their mongo-safe form
parameters = {
ParameterKeyEscaper.unescape(k): v for k, v in parameters.items()
}
dpath.set(task_data, "execution/parameters", parameters)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None
):
@@ -242,25 +313,7 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
conform_tag_fields(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
try:
path = "script/%s" % field
value = dpath.get(fields, path)
if isinstance(value, six.string_types):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
parameters = {k.strip(): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
return prepare_for_save(call, fields)
def _validate_and_get_task_from_call(call: APICall, **kwargs):
@@ -320,8 +373,7 @@ def prepare_update_fields(call: APICall, task, call_data):
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
conform_tag_fields(call, fields)
return fields, valid_fields
return prepare_for_save(call, fields), valid_fields
@endpoint(
@@ -348,7 +400,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
)
update_project_time(updated_fields.get("project"))
conform_output_tags(call, updated_fields)
unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -473,7 +525,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
conform_output_tags(call, fields)
unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
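To make the projection and ordering handling above easier to follow, here is an illustrative sketch of what escape_execution_parameters does to incoming field paths; the double pass over order_by covers both the ascending and the descending ("-"-prefixed) entries. The request values and parameter names below are made-up; only the ParameterKeyEscaper import and the "execution.parameters." prefix come from the code above.

# Illustrative only: rewrite "execution.parameters.*" paths the way the endpoint does.
from bll.task import ParameterKeyEscaper

PREFIX = "execution.parameters."

def escape_paths(paths, prefix=PREFIX):
    # Same shape as the helper in the diff: escape only the key part after the prefix
    return [
        prefix + ParameterKeyEscaper.escape(path[len(prefix):])
        if path.startswith(prefix)
        else path
        for path in paths
    ]

projection = ["id", "execution.parameters.learning.rate"]
order_by = ["-execution.parameters.learning.rate", "created"]

print(escape_paths(projection))
# ['id', 'execution.parameters.learning%2Erate']
print(escape_paths(escape_paths(order_by), "-" + PREFIX))
# ['-execution.parameters.learning%2Erate', 'created']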

@@ -52,7 +52,7 @@ class TestEntityOrdering(TestService):
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
order_by=[order_by] if order_by else None,
comment=self.test_comment,
page=page,
page_size=page_size,
@@ -79,7 +79,7 @@ def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
order_by=[order_by] if order_by else None,
comment=self.test_comment,
**kwargs,
).tasks