diff --git a/server/apimodels/__init__.py b/server/apimodels/__init__.py index eb295cc..b8447ed 100644 --- a/server/apimodels/__init__.py +++ b/server/apimodels/__init__.py @@ -5,12 +5,12 @@ from typing import Union, Type, Iterable import jsonmodels.errors import six -import validators from jsonmodels import fields from jsonmodels.fields import _LazyType, NotSet from jsonmodels.models import Base as ModelBase from jsonmodels.validators import Enum as EnumValidator from luqum.parser import parser, ParseError +from validators import email as email_validator, domain as domain_validator from apierrors import errors @@ -66,9 +66,7 @@ class DictField(fields.BaseField): value_types = tuple() return tuple( - _LazyType(type_) - if isinstance(type_, six.string_types) - else type_ + _LazyType(type_) if isinstance(type_, six.string_types) else type_ for type_ in value_types ) @@ -107,7 +105,7 @@ class IntField(fields.IntField): def validate_lucene_query(value): - if value == '': + if value == "": return try: parser.parse(value) @@ -125,6 +123,7 @@ class LuceneQueryField(fields.StringField): class NullableEnumValidator(EnumValidator): """Validator for enums that allows a None value.""" + def validate(self, value): if value is not None: super(NullableEnumValidator, self).validate(value) @@ -153,10 +152,6 @@ class EnumField(fields.StringField): class ActualEnumField(fields.StringField): - @property - def types(self): - return self.__enum, - def __init__( self, enum_class: Type[Enum], @@ -167,6 +162,7 @@ class ActualEnumField(fields.StringField): **kwargs ): self.__enum = enum_class + self.types = (enum_class,) # noinspection PyTypeChecker choices = list(enum_class) validator_cls = EnumValidator if required else NullableEnumValidator @@ -197,7 +193,7 @@ class EmailField(fields.StringField): super().validate(value) if value is None: return - if validators.email(value) is not True: + if email_validator(value) is not True: raise errors.bad_request.InvalidEmailAddress() @@ -206,7 
+202,7 @@ class DomainField(fields.StringField): super().validate(value) if value is None: return - if validators.domain(value) is not True: + if domain_validator(value) is not True: raise errors.bad_request.InvalidDomainName() diff --git a/server/bll/task/__init__.py b/server/bll/task/__init__.py index 544b289..fcfa038 100644 --- a/server/bll/task/__init__.py +++ b/server/bll/task/__init__.py @@ -4,4 +4,5 @@ from .utils import ( update_project_time, validate_status_change, split_by, + ParameterKeyEscaper, ) diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index 90ac870..2df6f8a 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -1,4 +1,3 @@ -import re from collections import OrderedDict from datetime import datetime, timedelta from operator import attrgetter @@ -32,8 +31,7 @@ from service_repo import APICall from timing_context import TimingContext from utilities.dicts import deep_merge from utilities.threads_manager import ThreadsManager -from .utils import ChangeStatusRequest, validate_status_change - +from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper log = config.logger(__file__) @@ -171,6 +169,11 @@ class TaskBLL(object): task = cls.get_by_id(company_id=company_id, task_id=task_id) execution_dict = task.execution.to_proper_dict() if task.execution else {} if execution_overrides: + parameters = execution_overrides.get("parameters") + if parameters is not None: + execution_overrides["parameters"] = { + ParameterKeyEscaper.escape(k): v for k, v in parameters.items() + } execution_dict = deep_merge(execution_dict, execution_overrides) artifacts = execution_dict.get("artifacts") if artifacts: @@ -178,25 +181,28 @@ class TaskBLL(object): a for a in artifacts if a.get("mode") != ArtifactModes.output ] now = datetime.utcnow() - new_task = Task( - id=create_id(), - user=user_id, - company=company_id, - created=now, - last_update=now, - name=name or task.name, - comment=comment or 
task.comment, - parent=parent or task.parent, - project=project or task.project, - tags=tags or task.tags, - system_tags=system_tags or [], - type=task.type, - script=task.script, - output=Output(destination=task.output.destination) if task.output else None, - execution=execution_dict, - ) - cls.validate(new_task) - new_task.save() + + with translate_errors_context(): + new_task = Task( + id=create_id(), + user=user_id, + company=company_id, + created=now, + last_update=now, + name=name or task.name, + comment=comment or task.comment, + parent=parent or task.parent, + project=project or task.project, + tags=tags or task.tags, + system_tags=system_tags or [], + type=task.type, + script=task.script, + output=Output(destination=task.output.destination) if task.output else None, + execution=execution_dict, + ) + cls.validate(new_task) + new_task.save() + return new_task @classmethod @@ -215,18 +221,6 @@ class TaskBLL(object): cls.validate_execution_model(task) - if task.execution: - if task.execution.parameters: - cls._validate_execution_parameters(task.execution.parameters) - - @staticmethod - def _validate_execution_parameters(parameters): - invalid_keys = [k for k in parameters if re.search(r"\s", k)] - if invalid_keys: - raise errors.bad_request.ValidationError( - "execution.parameters keys contain whitespace", keys=invalid_keys - ) - @staticmethod def get_unique_metric_variants(company_id, project_ids=None): pipeline = [ @@ -658,7 +652,10 @@ class TaskBLL(object): if result: total = int(result.get("total", -1)) - results = [r["_id"] for r in result.get("results", [])] + results = [ + ParameterKeyEscaper.unescape(r["_id"]) + for r in result.get("results", []) + ] remaining = max(0, total - (len(results) + page * page_size)) return total, remaining, results diff --git a/server/bll/task/utils.py b/server/bll/task/utils.py index c0580be..ab74afa 100644 --- a/server/bll/task/utils.py +++ b/server/bll/task/utils.py @@ -3,6 +3,7 @@ from typing import TypeVar, Callable, 
Tuple, Sequence import attr import six +from boltons.dictutils import OneToOne from apierrors import errors from database.errors import translate_errors_context @@ -171,3 +172,26 @@ def split_by( [item for cond, item in applied if cond], [item for cond, item in applied if not cond], ) + + +class ParameterKeyEscaper: + _mapping = OneToOne({".": "%2E", "$": "%24"}) + + @classmethod + def escape(cls, value): + """ Quote a parameter key """ + value = value.strip().replace("%", "%%") + for c, r in cls._mapping.items(): + value = value.replace(c, r) + return value + + @classmethod + def _unescape(cls, value): + for c, r in cls._mapping.inv.items(): + value = value.replace(c, r) + return value + + @classmethod + def unescape(cls, value): + """ Unquote a quoted parameter key """ + return "%".join(map(cls._unescape, value.split("%%"))) diff --git a/server/database/model/base.py b/server/database/model/base.py index e916dbf..272f8ef 100644 --- a/server/database/model/base.py +++ b/server/database/model/base.py @@ -1,7 +1,7 @@ import re from collections import namedtuple from functools import reduce -from typing import Collection, Sequence, Union +from typing import Collection, Sequence, Union, Optional from boltons.iterutils import first from dateutil.parser import parse as parse_datetime @@ -60,7 +60,7 @@ class ProperDictMixin(object): class GetMixin(PropsMixin): _text_score = "$text_score" - + _projection_key = "projection" _ordering_key = "order_by" _search_text_key = "search_text" @@ -270,11 +270,26 @@ class GetMixin(PropsMixin): return override_projection if not parameters: return [] - return parameters.get("projection") or parameters.get("only_fields", []) + return parameters.get(cls._projection_key) or parameters.get("only_fields", []) @classmethod - def set_default_ordering(cls, parameters, value): - parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value + def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]: + 
parameters.pop("only_fields", None) + parameters[cls._projection_key] = value + return value + + @classmethod + def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]: + return parameters.get(cls._ordering_key) + + @classmethod + def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]: + parameters[cls._ordering_key] = value + return value + + @classmethod + def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None: + cls.set_ordering(parameters, cls.get_ordering(parameters) or value) @classmethod def get_many_with_join( diff --git a/server/services/tasks.py b/server/services/tasks.py index e93db64..2a9984d 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -1,12 +1,11 @@ from copy import deepcopy from datetime import datetime from operator import attrgetter -from typing import Sequence, Callable, Type, TypeVar +from typing import Sequence, Callable, Type, TypeVar, Union import attr import dpath import mongoengine -import six from mongoengine import EmbeddedDocument, Q from mongoengine.queryset.transform import COMPARISON_OPERATORS from pymongo import UpdateOne @@ -33,7 +32,13 @@ from apimodels.tasks import ( ) from bll.event import EventBLL from bll.queue import QueueBLL -from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by +from bll.task import ( + TaskBLL, + ChangeStatusRequest, + update_project_time, + split_by, + ParameterKeyEscaper, +) from bll.util import SetFieldsResolver from database.errors import translate_errors_context from database.model.model import Model @@ -97,13 +102,37 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest): req_model.task, company_id=company_id, allow_public=True ) task_dict = task.to_proper_dict() - conform_output_tags(call, task_dict) + unprepare_from_saved(call, task_dict) call.result.data = {"task": task_dict} +def escape_execution_parameters(call: APICall): + default_prefix = "execution.parameters." 
+ + def escape_paths(paths, prefix=default_prefix): + return [ + prefix + ParameterKeyEscaper.escape(path[len(prefix) :]) + if path.startswith(prefix) + else path + for path in paths + ] + + projection = Task.get_projection(call.data) + if projection: + Task.set_projection(call.data, escape_paths(projection)) + + ordering = Task.get_ordering(call.data) + if ordering: + ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix)) + Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix)) + + @endpoint("tasks.get_all_ex", required_fields=[]) def get_all_ex(call: APICall): conform_tag_fields(call, call.data) + + escape_execution_parameters(call) + with translate_errors_context(): with TimingContext("mongo", "task_get_all_ex"): tasks = Task.get_many_with_join( @@ -112,13 +141,16 @@ def get_all_ex(call: APICall): query_options=get_all_query_options, allow_public=True, # required in case projection is requested for public dataset/versions ) - conform_output_tags(call, tasks) + unprepare_from_saved(call, tasks) call.result.data = {"tasks": tasks} @endpoint("tasks.get_all", required_fields=[]) def get_all(call: APICall): conform_tag_fields(call, call.data) + + escape_execution_parameters(call) + with translate_errors_context(): with TimingContext("mongo", "task_get_all"): tasks = Task.get_many( @@ -128,7 +160,7 @@ def get_all(call: APICall): query_options=get_all_query_options, allow_public=True, # required in case projection is requested for public dataset/versions ) - conform_output_tags(call, tasks) + unprepare_from_saved(call, tasks) call.result.data = {"tasks": tasks} @@ -223,6 +255,45 @@ create_fields = { } +def prepare_for_save(call: APICall, fields: dict): + conform_tag_fields(call, fields) + + # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths + for field in task_script_fields: + try: + path = f"script/{field}" + value = dpath.get(fields, path) + if isinstance(value, str): 
+ value = value.strip() + dpath.set(fields, path, value) + except KeyError: + pass + + parameters = safe_get(fields, "execution/parameters") + if parameters is not None: + # Escape keys to make them mongo-safe + parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()} + dpath.set(fields, "execution/parameters", parameters) + + return fields + + +def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]): + if isinstance(tasks_data, dict): + tasks_data = [tasks_data] + + conform_output_tags(call, tasks_data) + + for task_data in tasks_data: + parameters = safe_get(task_data, "execution/parameters") + if parameters is not None: + # Unescape keys that were escaped to make them mongo-safe + parameters = { + ParameterKeyEscaper.unescape(k): v for k, v in parameters.items() + } + dpath.set(task_data, "execution/parameters", parameters) + + def prepare_create_fields( call: APICall, valid_fields=None, output=None, previous_task: Task = None ): @@ -242,25 +313,7 @@ def prepare_create_fields( output = Output(destination=output_dest) fields["output"] = output - conform_tag_fields(call, fields) - - # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths - for field in task_script_fields: - try: - path = "script/%s" % field - value = dpath.get(fields, path) - if isinstance(value, six.string_types): - value = value.strip() - dpath.set(fields, path, value) - except KeyError: - pass - - parameters = safe_get(fields, "execution/parameters") - if parameters is not None: - parameters = {k.strip(): v for k, v in parameters.items()} - dpath.set(fields, "execution/parameters", parameters) - - return fields + return prepare_for_save(call, fields) def _validate_and_get_task_from_call(call: APICall, **kwargs): @@ -320,8 +373,7 @@ def prepare_update_fields(call: APICall, task, call_data): t_fields = task_fields t_fields.add("output__error") fields = parse_from_call(call_data, update_fields, t_fields) -
conform_tag_fields(call, fields) - return fields, valid_fields + return prepare_for_save(call, fields), valid_fields @endpoint( @@ -348,7 +400,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest): ) update_project_time(updated_fields.get("project")) - conform_output_tags(call, updated_fields) + unprepare_from_saved(call, updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields) @@ -473,7 +525,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): fixed_fields.update(last_update=now) updated = task.update(upsert=False, **fixed_fields) update_project_time(fields.get("project")) - conform_output_tags(call, fields) + unprepare_from_saved(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) else: call.result.data_model = UpdateResponse(updated=0) diff --git a/server/tests/automated/test_entity_ordering.py b/server/tests/automated/test_entity_ordering.py index ffc7573..5858ac9 100644 --- a/server/tests/automated/test_entity_ordering.py +++ b/server/tests/automated/test_entity_ordering.py @@ -52,7 +52,7 @@ class TestEntityOrdering(TestService): def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence: return self.api.tasks.get_all_ex( only_fields=self.only_fields, - order_by=order_by, + order_by=[order_by] if order_by else None, comment=self.test_comment, page=page, page_size=page_size, @@ -79,7 +79,7 @@ class TestEntityOrdering(TestService): def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs): tasks = self.api.tasks.get_all_ex( only_fields=self.only_fields, - order_by=order_by, + order_by=[order_by] if order_by else None, comment=self.test_comment, **kwargs, ).tasks