From 0303c3525f0e7e6a7522b41067b3128e905ea856 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 17:57:58 +0200 Subject: [PATCH] API version bump Update internal tests Allow edit/delete task artifacts/hyperparams/configs using force flag Improve lists query support for get_all calls --- apiserver/api_version.py | 2 +- apiserver/apimodels/tasks.py | 6 ++ apiserver/bll/task/artifacts.py | 20 ++++-- apiserver/bll/task/hyperparams.py | 20 ++++-- apiserver/bll/task/utils.py | 7 +- apiserver/database/model/base.py | 17 +++-- apiserver/schema/services/tasks.conf | 71 +++++++++++++++++-- apiserver/services/tasks.py | 39 +++++++--- .../tests/automated/test_task_artifacts.py | 23 +++++- .../tests/automated/test_task_hyperparams.py | 33 ++++++++- apiserver/tests/automated/test_tasks_edit.py | 2 +- 11 files changed, 207 insertions(+), 33 deletions(-) diff --git a/apiserver/api_version.py b/apiserver/api_version.py index 7f6646a..95a6d3a 100644 --- a/apiserver/api_version.py +++ b/apiserver/api_version.py @@ -1 +1 @@ -__version__ = "2.11.0" +__version__ = "2.12.0" diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 23d8647..7658018 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -119,6 +119,7 @@ class CloneRequest(TaskRequest): class AddOrUpdateArtifactsRequest(TaskRequest): artifacts = ListField([Artifact], validators=Length(minimum_value=1)) + force = BoolField(default=False) class ArtifactId(models.Base): @@ -130,6 +131,7 @@ class ArtifactId(models.Base): class DeleteArtifactsRequest(TaskRequest): artifacts = ListField([ArtifactId], validators=Length(minimum_value=1)) + force = BoolField(default=False) class ResetRequest(UpdateRequest): @@ -166,6 +168,7 @@ class EditHyperParamsRequest(TaskRequest): validators=Enum(*get_options(ReplaceHyperparams)), default=ReplaceHyperparams.none, ) + force = BoolField(default=False) class HyperParamKey(models.Base): @@ -177,6 +180,7 @@ class DeleteHyperParamsRequest(TaskRequest): hyperparams: Sequence[HyperParamKey] = ListField( [HyperParamKey], validators=Length(minimum_value=1) ) + force = BoolField(default=False) class GetConfigurationsRequest(MultiTaskRequest): @@ -199,10 +203,12 @@ class EditConfigurationRequest(TaskRequest): [Configuration], validators=Length(minimum_value=1) ) replace_configuration = BoolField(default=False) + force = BoolField(default=False) class DeleteConfigurationRequest(TaskRequest): configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1)) + force = BoolField(default=False) class ArchiveRequest(MultiTaskRequest): diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index c9f5da7..ee8c3db 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -48,11 +48,17 @@ def artifacts_unprepare_from_saved(fields): class Artifacts: @classmethod def add_or_update_artifacts( - cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact], + cls, + company_id: str, + task_id: str, + artifacts: Sequence[ApiArtifact], + force: bool, ) -> int: with TimingContext("mongo", "update_artifacts"): task = get_task_for_update( - company_id=company_id, task_id=task_id, allow_all_statuses=True + company_id=company_id, + task_id=task_id, + force=force, ) artifacts = { @@ -68,11 +74,17 @@ class Artifacts: @classmethod def delete_artifacts( - cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId] + cls, + company_id: str, + task_id: str, + artifact_ids: Sequence[ArtifactId], + force: bool, ) -> int: with TimingContext("mongo", "delete_artifacts"): task = get_task_for_update( - company_id=company_id, task_id=task_id, allow_all_statuses=True + company_id=company_id, + task_id=task_id, + force=force, ) artifact_ids = [ diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index a21181a..bd5c043 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -63,7 +63,11 @@ class HyperParams: @classmethod def delete_params( - cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey] + cls, + company_id: str, + task_id: str, + hyperparams: Sequence[HyperParamKey], + force: bool, ) -> int: with TimingContext("mongo", "delete_hyperparams"): properties_only = cls._normalize_params(hyperparams) @@ -71,6 +75,7 @@ class HyperParams: company_id=company_id, task_id=task_id, allow_all_statuses=properties_only, + force=force, ) with_param, without_param = iterutils.partition( @@ -100,6 +105,7 @@ class HyperParams: task_id: str, hyperparams: Sequence[HyperParamItem], replace_hyperparams: str, + force: bool, ) -> int: with TimingContext("mongo", "edit_hyperparams"): properties_only = cls._normalize_params(hyperparams) @@ -107,6 +113,7 @@ class HyperParams: company_id=company_id, task_id=task_id, allow_all_statuses=properties_only, + force=force, ) update_cmds = dict() @@ -198,9 +205,12 @@ class HyperParams: task_id: str, configuration: Sequence[Configuration], replace_configuration: bool, + force: bool, ) -> int: with TimingContext("mongo", "edit_configuration"): - task = get_task_for_update(company_id=company_id, task_id=task_id) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force + ) update_cmds = dict() configuration = { @@ -217,10 +227,12 @@ class HyperParams: @classmethod def delete_configuration( - cls, company_id: str, task_id: str, configuration: Sequence[str] + cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool ) -> int: with TimingContext("mongo", "delete_configuration"): - task = get_task_for_update(company_id=company_id, task_id=task_id) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force + ) delete_cmds = { f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1 diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 6e5cf31..82a3bc6 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -174,7 +174,7 @@ def split_by( def get_task_for_update( - company_id: str, task_id: str, allow_all_statuses: bool = False + company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False ) -> Task: """ Loads only task id and return the task only if it is updatable (status == 'created') @@ -186,7 +186,10 @@ def get_task_for_update( if allow_all_statuses: return task - if task.status != TaskStatus.created: + allowed_statuses = ( + [TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created] + ) + if task.status not in allowed_statuses: raise errors.bad_request.InvalidTaskStatus( expected=TaskStatus.created, status=task.status ) diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index 32ceeb8..66f0f39 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -104,8 +104,13 @@ class GetMixin(PropsMixin): legacy_exclude_prefix = "-" _default = "in" - _ops = {"not": "nin"} + _ops = { + "not": ("nin", False), + "all": ("all", True), + "and": ("all", True), + } _next = _default + _sticky = False def __init__(self, legacy=False): self._legacy = legacy @@ -116,13 +121,16 @@ class GetMixin(PropsMixin): return self._default elif self._legacy and v.startswith(self.legacy_exclude_prefix): self._next = self._default - return self._ops["not"] + return self._ops["not"][0] elif v.startswith(self.op_prefix): - self._next = self._ops.get(v[len(self.op_prefix) :], self._default) + self._next, self._sticky = self._ops.get( + v[len(self.op_prefix) :], (self._default, self._sticky) + ) return None next_ = self._next - self._next = self._default + if not self._sticky: + self._next = self._default return next_ def value_transform(self, v): @@ -260,6 +268,7 @@ class GetMixin(PropsMixin): - Exclusion can be specified by a leading "-" for each value (API versions <2.8) or by a preceding "__$not" value (operator) + - AND can be achieved using a preceding "__$all" or "__$and" value (operator) """ if not isinstance(data, (list, tuple)): raise MakeGetAllQueryError("expected list", field) diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index c597e14..fe392fb 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -1215,7 +1215,7 @@ delete { } } archive { - "2.11" { + "2.12" { description: """Archive tasks. If a task is queued it will first be dequeued and then archived. """ @@ -1629,6 +1629,10 @@ add_or_update_artifacts { type: array items {"$ref": "#/definitions/artifact"} } + force { + description: "If set to True then both new and running task artifacts can be edited. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { @@ -1658,6 +1662,10 @@ delete_artifacts { type: array items {"$ref": "#/definitions/artifact_id"} } + force { + description: "If set to True then both new and running task artifacts can be deleted. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { @@ -1740,8 +1748,22 @@ get_hyper_params { type: object properties { params { - type: object description: "Hyper parameters (keyed by task ID)" + type: array + items { + type: object + properties { + "task": { + description: "Task ID" + type: string + } + "hyperparams": { + description: "Hyper parameters" + type: array + items {"$ref": "#/definitions/params_item"} + } + } + } } } } @@ -1770,6 +1792,10 @@ edit_hyper_params { 'none' (the default value) - only the specific parameters will be updated or added""" "$ref": "#/definitions/replace_hyperparams_enum" } + force { + description: "If set to True then both new and running task hyper params can be edited. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { @@ -1799,6 +1825,10 @@ delete_hyper_params { type: array items { "$ref": "#/definitions/param_key" } } + force { + description: "If set to True then both new and running task hyper params can be deleted. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { @@ -1836,8 +1866,22 @@ get_configurations { type: object properties { configurations { - type: object description: "Configurations (keyed by task ID)" + type: array + items { + type: object + properties { + "task" { + description: "Task ID" + type: string + } + "configuration" { + description: "Configuration list" + type: array + items {"$ref": "#/definitions/configuration_item"} + } + } + } } } } @@ -1861,8 +1905,19 @@ get_configuration_names { type: object properties { configurations { - type: object description: "Names of task configuration items (keyed by task ID)" + type: object + properties { + task { + description: "Task ID" + type: string + } + names { + description: "Configuration names" + type: array + items {type: string} + } + } } } } @@ -1888,6 +1943,10 @@ edit_configuration { description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added" type: boolean } + force { + description: "If set to True then both new and running task configuration can be edited. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { @@ -1917,6 +1976,10 @@ delete_configuration { type: array items { type: string } } + force { + description: "If set to True then both new and running task configuration can be deleted. Otherwise only the new task ones. Default is False" + type: boolean + } } } response { diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index d6a7ba3..77d9eb7 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -74,6 +74,7 @@ from apiserver.database.model.task.task import ( Script, DEFAULT_LAST_ITERATION, Execution, + ArtifactModes, ) from apiserver.database.utils import get_fields_attr, parse_from_call from apiserver.service_repo import APICall, endpoint @@ -642,6 +643,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest task_id=request.task, hyperparams=request.hyperparams, replace_hyperparams=request.replace_hyperparams, + force=request.force, ) } @@ -651,7 +653,10 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq with translate_errors_context(): call.result.data = { "deleted": HyperParams.delete_params( - company_id, task_id=request.task, hyperparams=request.hyperparams + company_id, + task_id=request.task, + hyperparams=request.hyperparams, + force=request.force, ) } @@ -699,6 +704,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ task_id=request.task, configuration=request.configuration, replace_configuration=request.replace_configuration, + force=request.force, ) } @@ -710,7 +716,10 @@ def delete_configuration( with translate_errors_context(): call.result.data = { "deleted": HyperParams.delete_configuration( - company_id, task_id=request.task, configuration=request.configuration + company_id, + task_id=request.task, + configuration=request.configuration, + force=request.force, ) } @@ -782,15 +791,20 @@ def enqueue(call: APICall, company_id, req_model: EnqueueRequest): request_data_model=UpdateRequest, response_data_model=DequeueResponse, ) -def dequeue(call: APICall, company_id, req_model: UpdateRequest): +def dequeue(call: APICall, company_id, request: UpdateRequest): task = TaskBLL.get_task_with_access( - req_model.task, + request.task, company_id=company_id, only=("id", "execution", "status", "project"), requires_write_access=True, ) res = DequeueResponse( - **TaskBLL.dequeue_and_change_status(task, company_id, req_model) + **TaskBLL.dequeue_and_change_status( + task, + company_id, + status_message=request.status_message, + status_reason=request.status_reason, + ) ) res.dequeued = 1 @@ -846,8 +860,8 @@ def reset(call: APICall, company_id, request: ResetRequest): updates.update( set__execution__artifacts={ key: artifact - for key, artifact in task.execution.artifacts - if artifact.get("mode") == "input" + for key, artifact in task.execution.artifacts.items() + if artifact.mode == ArtifactModes.input } ) @@ -861,6 +875,9 @@ def reset(call: APICall, company_id, request: ResetRequest): ).execute(started=None, completed=None, published=None, **updates) ) + # do not return artifacts since they are not serializable + res.fields.pop("execution.artifacts", None) + for key, value in api_results.items(): setattr(res, key, value) @@ -892,7 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): status_reason=request.status_reason, system_tags=sorted( set(task.system_tags) | {EntityVisibility.archived.value} - ) + ), ) archived += 1 @@ -1132,7 +1149,10 @@ def add_or_update_artifacts( with translate_errors_context(): call.result.data = { "updated": Artifacts.add_or_update_artifacts( - company_id=company_id, task_id=request.task, artifacts=request.artifacts + company_id=company_id, + task_id=request.task, + artifacts=request.artifacts, + force=request.force, ) } @@ -1149,6 +1169,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest) company_id=company_id, task_id=request.task, artifact_ids=request.artifacts, + force=request.force, ) } diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py index 7dc0e7d..44e4f86 100644 --- a/apiserver/tests/automated/test_task_artifacts.py +++ b/apiserver/tests/automated/test_task_artifacts.py @@ -1,6 +1,7 @@ from operator import itemgetter from typing import Sequence +from apiserver.apierrors.errors.bad_request import InvalidTaskStatus from apiserver.tests.automated import TestService @@ -58,14 +59,15 @@ class TestTasksArtifacts(TestService): def test_artifacts_edit_delete(self): artifacts = [ - dict(key="a", type="str", uri="test1"), + dict(key="a", type="str", uri="test1", mode="input"), dict(key="b", type="int", uri="test2"), + dict(key="c", type="int", uri="test3"), ] task = self.new_task(execution={"artifacts": artifacts}) # test add_or_update edit = [ - dict(key="a", type="str", uri="hello"), + dict(key="a", type="str", uri="hello", mode="input"), dict(key="c", type="int", uri="world"), ] res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit) @@ -78,6 +80,23 @@ class TestTasksArtifacts(TestService): res = self.api.tasks.get_all_ex(id=[task]).tasks[0] self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res) + # test edit running task + self.api.tasks.started(task=task) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit) + self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True) + res = self.api.tasks.get_all_ex(id=[task]).tasks[0] + self._assertTaskArtifacts(artifacts, res) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}]) + self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True) + res = self.api.tasks.get_all_ex(id=[task]).tasks[0] + self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res) + + self.api.tasks.reset(task=task) + res = self.api.tasks.get_all_ex(id=[task]).tasks[0] + self._assertTaskArtifacts(artifacts[0: 1], res) + def _update_source(self, source: Sequence[dict], update: Sequence[dict]): dict1 = {s["key"]: s for s in source} dict2 = {u["key"]: u for u in update} diff --git a/apiserver/tests/automated/test_task_hyperparams.py b/apiserver/tests/automated/test_task_hyperparams.py index 104cd21..7ecb1c9 100644 --- a/apiserver/tests/automated/test_task_hyperparams.py +++ b/apiserver/tests/automated/test_task_hyperparams.py @@ -118,11 +118,23 @@ class TestTasksHyperparams(TestService): self.api.tasks.edit_hyper_params( task=task, hyperparams=[dict(section="test", name="x", value="123")] ) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.delete_hyper_params( + task=task, hyperparams=[dict(section="test")] + ) + self.api.tasks.edit_hyper_params( + task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True + ) + self.api.tasks.delete_hyper_params( + task=task, hyperparams=[dict(section="test")], force=True + ) + + # properties section can be edited/deleted in any task state without the flag self.api.tasks.edit_hyper_params( task=task, hyperparams=[dict(section="properties", name="x", value="123")] ) self.api.tasks.delete_hyper_params( - task=task, hyperparams=[dict(section="Properties")] + task=task, hyperparams=[dict(section="properties")] ) @staticmethod @@ -204,7 +216,7 @@ class TestTasksHyperparams(TestService): # delete new_to_delete = self._get_config_keys(new_config[1:]) - res = self.api.tasks.delete_configuration( + self.api.tasks.delete_configuration( task=task, configuration=new_to_delete ) res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] @@ -218,6 +230,23 @@ class TestTasksHyperparams(TestService): finally: self.api.tasks.delete(task=new_task, force=True) + # edit/delete of running task + self.api.tasks.started(task=task) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.edit_configuration( + task=task, configuration=new_config + ) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.delete_configuration( + task=task, configuration=new_to_delete + ) + self.api.tasks.edit_configuration( + task=task, configuration=new_config, force=True + ) + self.api.tasks.delete_configuration( + task=task, configuration=new_to_delete, force=True + ) + @staticmethod def _get_config_keys(config: Sequence[dict]) -> List[dict]: return [c["name"] for c in config] diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py index 76ab677..59b5eb5 100644 --- a/apiserver/tests/automated/test_tasks_edit.py +++ b/apiserver/tests/automated/test_tasks_edit.py @@ -9,7 +9,7 @@ log = config.logger(__file__) class TestTasksEdit(TestService): def setUp(self, **kwargs): - super().setUp(version="2.9") + super().setUp(version="2.12") def new_task(self, **kwargs): self.update_missing(