Allow using "$", "." and whitespace in hyper-parameter keys

This commit is contained in:
allegroai 2020-01-02 15:28:50 +02:00
parent 7d10bbdf8e
commit dedac3b2fe
7 changed files with 167 additions and 82 deletions

View File

@ -5,12 +5,12 @@ from typing import Union, Type, Iterable
import jsonmodels.errors import jsonmodels.errors
import six import six
import validators
from jsonmodels import fields from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors from apierrors import errors
@ -66,9 +66,7 @@ class DictField(fields.BaseField):
value_types = tuple() value_types = tuple()
return tuple( return tuple(
_LazyType(type_) _LazyType(type_) if isinstance(type_, six.string_types) else type_
if isinstance(type_, six.string_types)
else type_
for type_ in value_types for type_ in value_types
) )
@ -107,7 +105,7 @@ class IntField(fields.IntField):
def validate_lucene_query(value): def validate_lucene_query(value):
if value == '': if value == "":
return return
try: try:
parser.parse(value) parser.parse(value)
@ -125,6 +123,7 @@ class LuceneQueryField(fields.StringField):
class NullableEnumValidator(EnumValidator): class NullableEnumValidator(EnumValidator):
"""Validator for enums that allows a None value.""" """Validator for enums that allows a None value."""
def validate(self, value): def validate(self, value):
if value is not None: if value is not None:
super(NullableEnumValidator, self).validate(value) super(NullableEnumValidator, self).validate(value)
@ -153,10 +152,6 @@ class EnumField(fields.StringField):
class ActualEnumField(fields.StringField): class ActualEnumField(fields.StringField):
@property
def types(self):
return self.__enum,
def __init__( def __init__(
self, self,
enum_class: Type[Enum], enum_class: Type[Enum],
@ -167,6 +162,7 @@ class ActualEnumField(fields.StringField):
**kwargs **kwargs
): ):
self.__enum = enum_class self.__enum = enum_class
self.types = (enum_class,)
# noinspection PyTypeChecker # noinspection PyTypeChecker
choices = list(enum_class) choices = list(enum_class)
validator_cls = EnumValidator if required else NullableEnumValidator validator_cls = EnumValidator if required else NullableEnumValidator
@ -197,7 +193,7 @@ class EmailField(fields.StringField):
super().validate(value) super().validate(value)
if value is None: if value is None:
return return
if validators.email(value) is not True: if email_validator(value) is not True:
raise errors.bad_request.InvalidEmailAddress() raise errors.bad_request.InvalidEmailAddress()
@ -206,7 +202,7 @@ class DomainField(fields.StringField):
super().validate(value) super().validate(value)
if value is None: if value is None:
return return
if validators.domain(value) is not True: if domain_validator(value) is not True:
raise errors.bad_request.InvalidDomainName() raise errors.bad_request.InvalidDomainName()

View File

@ -4,4 +4,5 @@ from .utils import (
update_project_time, update_project_time,
validate_status_change, validate_status_change,
split_by, split_by,
ParameterKeyEscaper,
) )

View File

@ -1,4 +1,3 @@
import re
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from operator import attrgetter from operator import attrgetter
@ -32,8 +31,7 @@ from service_repo import APICall
from timing_context import TimingContext from timing_context import TimingContext
from utilities.dicts import deep_merge from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__) log = config.logger(__file__)
@ -171,6 +169,11 @@ class TaskBLL(object):
task = cls.get_by_id(company_id=company_id, task_id=task_id) task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {} execution_dict = task.execution.to_proper_dict() if task.execution else {}
if execution_overrides: if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides) execution_dict = deep_merge(execution_dict, execution_overrides)
artifacts = execution_dict.get("artifacts") artifacts = execution_dict.get("artifacts")
if artifacts: if artifacts:
@ -178,25 +181,28 @@ class TaskBLL(object):
a for a in artifacts if a.get("mode") != ArtifactModes.output a for a in artifacts if a.get("mode") != ArtifactModes.output
] ]
now = datetime.utcnow() now = datetime.utcnow()
new_task = Task(
id=create_id(), with translate_errors_context():
user=user_id, new_task = Task(
company=company_id, id=create_id(),
created=now, user=user_id,
last_update=now, company=company_id,
name=name or task.name, created=now,
comment=comment or task.comment, last_update=now,
parent=parent or task.parent, name=name or task.name,
project=project or task.project, comment=comment or task.comment,
tags=tags or task.tags, parent=parent or task.parent,
system_tags=system_tags or [], project=project or task.project,
type=task.type, tags=tags or task.tags,
script=task.script, system_tags=system_tags or [],
output=Output(destination=task.output.destination) if task.output else None, type=task.type,
execution=execution_dict, script=task.script,
) output=Output(destination=task.output.destination) if task.output else None,
cls.validate(new_task) execution=execution_dict,
new_task.save() )
cls.validate(new_task)
new_task.save()
return new_task return new_task
@classmethod @classmethod
@ -215,18 +221,6 @@ class TaskBLL(object):
cls.validate_execution_model(task) cls.validate_execution_model(task)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
if invalid_keys:
raise errors.bad_request.ValidationError(
"execution.parameters keys contain whitespace", keys=invalid_keys
)
@staticmethod @staticmethod
def get_unique_metric_variants(company_id, project_ids=None): def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [ pipeline = [
@ -658,7 +652,10 @@ class TaskBLL(object):
if result: if result:
total = int(result.get("total", -1)) total = int(result.get("total", -1))
results = [r["_id"] for r in result.get("results", [])] results = [
ParameterKeyEscaper.unescape(r["_id"])
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size)) remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results return total, remaining, results

View File

@ -3,6 +3,7 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr import attr
import six import six
from boltons.dictutils import OneToOne
from apierrors import errors from apierrors import errors
from database.errors import translate_errors_context from database.errors import translate_errors_context
@ -171,3 +172,26 @@ def split_by(
[item for cond, item in applied if cond], [item for cond, item in applied if cond],
[item for cond, item in applied if not cond], [item for cond, item in applied if not cond],
) )
class ParameterKeyEscaper:
    """Reversible quoting of hyper-parameter keys.

    ``escape`` trims surrounding whitespace, doubles every literal ``%`` and
    then substitutes ``.`` and ``$`` with the percent-codes held in
    ``_mapping``; ``unescape`` reverses those substitutions.
    NOTE(review): callers describe the escaped form as "mongo-safe" --
    presumably because such keys may not contain "." or "$"; confirm.
    """

    # Two-way table: raw character -> percent-code (``.inv`` gives the reverse)
    _mapping = OneToOne({".": "%2E", "$": "%24"})

    @classmethod
    def escape(cls, value):
        """ Quote a parameter key """
        trimmed = value.strip()
        # Double literal percent signs first so that the codes inserted below
        # can be told apart from percent signs the caller supplied
        quoted = trimmed.replace("%", "%%")
        for raw_char, code in cls._mapping.items():
            quoted = quoted.replace(raw_char, code)
        return quoted

    @classmethod
    def _unescape(cls, value):
        """ Replace every percent-code in value with its raw character """
        for code, raw_char in cls._mapping.inv.items():
            value = value.replace(code, raw_char)
        return value

    @classmethod
    def unescape(cls, value):
        """ Unquote a quoted parameter key """
        # Split on doubled percent signs so they are never mistaken for the
        # start of a percent-code, decode each piece, then restore a single "%"
        pieces = value.split("%%")
        return "%".join([cls._unescape(piece) for piece in pieces])

View File

@ -1,7 +1,7 @@
import re import re
from collections import namedtuple from collections import namedtuple
from functools import reduce from functools import reduce
from typing import Collection, Sequence, Union from typing import Collection, Sequence, Union, Optional
from boltons.iterutils import first from boltons.iterutils import first
from dateutil.parser import parse as parse_datetime from dateutil.parser import parse as parse_datetime
@ -60,7 +60,7 @@ class ProperDictMixin(object):
class GetMixin(PropsMixin): class GetMixin(PropsMixin):
_text_score = "$text_score" _text_score = "$text_score"
_projection_key = "projection"
_ordering_key = "order_by" _ordering_key = "order_by"
_search_text_key = "search_text" _search_text_key = "search_text"
@ -270,11 +270,26 @@ class GetMixin(PropsMixin):
return override_projection return override_projection
if not parameters: if not parameters:
return [] return []
return parameters.get("projection") or parameters.get("only_fields", []) return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod @classmethod
def set_default_ordering(cls, parameters, value): def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value parameters.pop("only_fields", None)
parameters[cls._projection_key] = value
return value
@classmethod
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
    """Return the value stored under the ordering key, or None when it is absent."""
    ordering_key = cls._ordering_key
    return parameters.get(ordering_key)
@classmethod
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
    """Store ``value`` under the ordering key of ``parameters`` and return it."""
    ordering_key = cls._ordering_key
    parameters[ordering_key] = value
    return value
@classmethod
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
    """Use ``value`` as the ordering unless ``parameters`` already holds a truthy one.

    The ordering entry is always (re)written: with the existing ordering when
    it is truthy, otherwise with ``value``.
    """
    existing = cls.get_ordering(parameters)
    cls.set_ordering(parameters, existing or value)
@classmethod @classmethod
def get_many_with_join( def get_many_with_join(

View File

@ -1,12 +1,11 @@
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from operator import attrgetter from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar from typing import Sequence, Callable, Type, TypeVar, Union
import attr import attr
import dpath import dpath
import mongoengine import mongoengine
import six
from mongoengine import EmbeddedDocument, Q from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne from pymongo import UpdateOne
@ -33,7 +32,13 @@ from apimodels.tasks import (
) )
from bll.event import EventBLL from bll.event import EventBLL
from bll.queue import QueueBLL from bll.queue import QueueBLL
from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by from bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
split_by,
ParameterKeyEscaper,
)
from bll.util import SetFieldsResolver from bll.util import SetFieldsResolver
from database.errors import translate_errors_context from database.errors import translate_errors_context
from database.model.model import Model from database.model.model import Model
@ -97,13 +102,37 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True req_model.task, company_id=company_id, allow_public=True
) )
task_dict = task.to_proper_dict() task_dict = task.to_proper_dict()
conform_output_tags(call, task_dict) unprepare_from_saved(call, task_dict)
call.result.data = {"task": task_dict} call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall):
    """Escape hyper-parameter keys referenced by the call's projection and ordering.

    Every path that starts with ``execution.parameters.`` has the remainder of
    the path passed through ``ParameterKeyEscaper.escape``; for the ordering,
    paths starting with ``-execution.parameters.`` are treated the same way.
    All other paths are left unchanged. The rewritten lists are stored back
    into ``call.data`` through ``Task.set_projection`` / ``Task.set_ordering``.
    """
    default_prefix = "execution.parameters."

    def escape_paths(paths, prefix=default_prefix):
        # Rewrite only the paths under `prefix`, escaping whatever follows it
        rewritten = []
        for path in paths:
            if path.startswith(prefix):
                key = path[len(prefix):]
                path = prefix + ParameterKeyEscaper.escape(key)
            rewritten.append(path)
        return rewritten

    projection = Task.get_projection(call.data)
    if projection:
        Task.set_projection(call.data, escape_paths(projection))

    ordering = Task.get_ordering(call.data)
    if not ordering:
        return

    # First pass: entries given with the bare prefix
    bare_escaped = escape_paths(ordering, default_prefix)
    ordering = Task.set_ordering(call.data, bare_escaped)
    # Second pass: entries carrying a leading "-" before the prefix
    # (presumably the descending-order marker -- confirm against the query code)
    Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
@endpoint("tasks.get_all_ex", required_fields=[]) @endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall): def get_all_ex(call: APICall):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"): with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
@ -112,13 +141,16 @@ def get_all_ex(call: APICall):
query_options=get_all_query_options, query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
conform_output_tags(call, tasks) unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks} call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[]) @endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall): def get_all(call: APICall):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
escape_execution_parameters(call)
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_all"): with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many( tasks = Task.get_many(
@ -128,7 +160,7 @@ def get_all(call: APICall):
query_options=get_all_query_options, query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
conform_output_tags(call, tasks) unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks} call.result.data = {"tasks": tasks}
@ -223,6 +255,45 @@ create_fields = {
} }
def prepare_for_save(call: APICall, fields: dict):
    """Normalise task fields in place before they are stored, and return them.

    Steps, in order: run ``conform_tag_fields`` on the fields, strip
    surrounding whitespace from string values of the ``script/<field>``
    entries listed in ``task_script_fields``, and escape the keys of
    ``execution/parameters`` with ``ParameterKeyEscaper.escape``.

    :param call: the API call being served (forwarded to ``conform_tag_fields``)
    :param fields: task fields dictionary; modified in place
    :return: the same ``fields`` dictionary
    """
    conform_tag_fields(call, fields)

    # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
    for field in task_script_fields:
        path = f"script/{field}"
        try:
            current = dpath.get(fields, path)
            cleaned = current.strip() if isinstance(current, str) else current
            dpath.set(fields, path, cleaned)
        except KeyError:
            # This script field is not part of the given fields
            continue

    parameters = safe_get(fields, "execution/parameters")
    if parameters is None:
        return fields

    # Escape keys to make them mongo-safe
    escaped = {}
    for key, val in parameters.items():
        escaped[ParameterKeyEscaper.escape(key)] = val
    dpath.set(fields, "execution/parameters", escaped)

    return fields
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """Convert stored task data back to its outward form, in place.

    Runs ``conform_output_tags`` over the task dictionaries and then restores
    the original hyper-parameter keys by applying
    ``ParameterKeyEscaper.unescape`` to every key of ``execution/parameters``.

    :param call: the API call being served (forwarded to ``conform_output_tags``)
    :param tasks_data: a single task dictionary or a sequence of them
    """
    # A single task dictionary is handled as a one-element list
    batch = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

    conform_output_tags(call, batch)

    for task_data in batch:
        parameters = safe_get(task_data, "execution/parameters")
        if parameters is None:
            continue
        # Unescape keys that were escaped when the task was saved
        restored = {}
        for key, val in parameters.items():
            restored[ParameterKeyEscaper.unescape(key)] = val
        dpath.set(task_data, "execution/parameters", restored)
def prepare_create_fields( def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None call: APICall, valid_fields=None, output=None, previous_task: Task = None
): ):
@ -242,25 +313,7 @@ def prepare_create_fields(
output = Output(destination=output_dest) output = Output(destination=output_dest)
fields["output"] = output fields["output"] = output
conform_tag_fields(call, fields) return prepare_for_save(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
try:
path = "script/%s" % field
value = dpath.get(fields, path)
if isinstance(value, six.string_types):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
parameters = {k.strip(): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
def _validate_and_get_task_from_call(call: APICall, **kwargs): def _validate_and_get_task_from_call(call: APICall, **kwargs):
@ -320,8 +373,7 @@ def prepare_update_fields(call: APICall, task, call_data):
t_fields = task_fields t_fields = task_fields
t_fields.add("output__error") t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields) fields = parse_from_call(call_data, update_fields, t_fields)
conform_tag_fields(call, fields) return prepare_for_save(call, fields), valid_fields
return fields, valid_fields
@endpoint( @endpoint(
@ -348,7 +400,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
) )
update_project_time(updated_fields.get("project")) update_project_time(updated_fields.get("project"))
conform_output_tags(call, updated_fields) unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields)
@ -473,7 +525,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fixed_fields.update(last_update=now) fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields) updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project")) update_project_time(fields.get("project"))
conform_output_tags(call, fields) unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else: else:
call.result.data_model = UpdateResponse(updated=0) call.result.data_model = UpdateResponse(updated=0)

View File

@ -52,7 +52,7 @@ class TestEntityOrdering(TestService):
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence: def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex( return self.api.tasks.get_all_ex(
only_fields=self.only_fields, only_fields=self.only_fields,
order_by=order_by, order_by=[order_by] if order_by else None,
comment=self.test_comment, comment=self.test_comment,
page=page, page=page,
page_size=page_size, page_size=page_size,
@ -79,7 +79,7 @@ class TestEntityOrdering(TestService):
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs): def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex( tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields, only_fields=self.only_fields,
order_by=order_by, order_by=[order_by] if order_by else None,
comment=self.test_comment, comment=self.test_comment,
**kwargs, **kwargs,
).tasks ).tasks