API version bump

Update internal tests
Allow edit/delete task artifacts/hyperparams/configs using force flag
Improve lists query support for get_all calls
This commit is contained in:
allegroai 2021-01-05 17:57:58 +02:00
parent 563c451ac9
commit 0303c3525f
11 changed files with 207 additions and 33 deletions

View File

@ -1 +1 @@
__version__ = "2.11.0" __version__ = "2.12.0"

View File

@ -119,6 +119,7 @@ class CloneRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskRequest): class AddOrUpdateArtifactsRequest(TaskRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1)) artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base): class ArtifactId(models.Base):
@ -130,6 +131,7 @@ class ArtifactId(models.Base):
class DeleteArtifactsRequest(TaskRequest): class DeleteArtifactsRequest(TaskRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1)) artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest): class ResetRequest(UpdateRequest):
@ -166,6 +168,7 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)), validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none, default=ReplaceHyperparams.none,
) )
force = BoolField(default=False)
class HyperParamKey(models.Base): class HyperParamKey(models.Base):
@ -177,6 +180,7 @@ class DeleteHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamKey] = ListField( hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1) [HyperParamKey], validators=Length(minimum_value=1)
) )
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest): class GetConfigurationsRequest(MultiTaskRequest):
@ -199,10 +203,12 @@ class EditConfigurationRequest(TaskRequest):
[Configuration], validators=Length(minimum_value=1) [Configuration], validators=Length(minimum_value=1)
) )
replace_configuration = BoolField(default=False) replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest): class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1)) configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest): class ArchiveRequest(MultiTaskRequest):

View File

@ -48,11 +48,17 @@ def artifacts_unprepare_from_saved(fields):
class Artifacts: class Artifacts:
@classmethod @classmethod
def add_or_update_artifacts( def add_or_update_artifacts(
cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact], cls,
company_id: str,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int: ) -> int:
with TimingContext("mongo", "update_artifacts"): with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update( task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True company_id=company_id,
task_id=task_id,
force=force,
) )
artifacts = { artifacts = {
@ -68,11 +74,17 @@ class Artifacts:
@classmethod @classmethod
def delete_artifacts( def delete_artifacts(
cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId] cls,
company_id: str,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int: ) -> int:
with TimingContext("mongo", "delete_artifacts"): with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update( task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True company_id=company_id,
task_id=task_id,
force=force,
) )
artifact_ids = [ artifact_ids = [

View File

@ -63,7 +63,11 @@ class HyperParams:
@classmethod @classmethod
def delete_params( def delete_params(
cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey] cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int: ) -> int:
with TimingContext("mongo", "delete_hyperparams"): with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams) properties_only = cls._normalize_params(hyperparams)
@ -71,6 +75,7 @@ class HyperParams:
company_id=company_id, company_id=company_id,
task_id=task_id, task_id=task_id,
allow_all_statuses=properties_only, allow_all_statuses=properties_only,
force=force,
) )
with_param, without_param = iterutils.partition( with_param, without_param = iterutils.partition(
@ -100,6 +105,7 @@ class HyperParams:
task_id: str, task_id: str,
hyperparams: Sequence[HyperParamItem], hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str, replace_hyperparams: str,
force: bool,
) -> int: ) -> int:
with TimingContext("mongo", "edit_hyperparams"): with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams) properties_only = cls._normalize_params(hyperparams)
@ -107,6 +113,7 @@ class HyperParams:
company_id=company_id, company_id=company_id,
task_id=task_id, task_id=task_id,
allow_all_statuses=properties_only, allow_all_statuses=properties_only,
force=force,
) )
update_cmds = dict() update_cmds = dict()
@ -198,9 +205,12 @@ class HyperParams:
task_id: str, task_id: str,
configuration: Sequence[Configuration], configuration: Sequence[Configuration],
replace_configuration: bool, replace_configuration: bool,
force: bool,
) -> int: ) -> int:
with TimingContext("mongo", "edit_configuration"): with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id) task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
update_cmds = dict() update_cmds = dict()
configuration = { configuration = {
@ -217,10 +227,12 @@ class HyperParams:
@classmethod @classmethod
def delete_configuration( def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str] cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
) -> int: ) -> int:
with TimingContext("mongo", "delete_configuration"): with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id) task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
delete_cmds = { delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1 f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1

View File

@ -174,7 +174,7 @@ def split_by(
def get_task_for_update( def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task: ) -> Task:
""" """
Loads only task id and return the task only if it is updatable (status == 'created') Loads only task id and return the task only if it is updatable (status == 'created')
@ -186,7 +186,10 @@ def get_task_for_update(
if allow_all_statuses: if allow_all_statuses:
return task return task
if task.status != TaskStatus.created: allowed_statuses = (
[TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
)
if task.status not in allowed_statuses:
raise errors.bad_request.InvalidTaskStatus( raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status expected=TaskStatus.created, status=task.status
) )

View File

@ -104,8 +104,13 @@ class GetMixin(PropsMixin):
legacy_exclude_prefix = "-" legacy_exclude_prefix = "-"
_default = "in" _default = "in"
_ops = {"not": "nin"} _ops = {
"not": ("nin", False),
"all": ("all", True),
"and": ("all", True),
}
_next = _default _next = _default
_sticky = False
def __init__(self, legacy=False): def __init__(self, legacy=False):
self._legacy = legacy self._legacy = legacy
@ -116,13 +121,16 @@ class GetMixin(PropsMixin):
return self._default return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix): elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default self._next = self._default
return self._ops["not"] return self._ops["not"][0]
elif v.startswith(self.op_prefix): elif v.startswith(self.op_prefix):
self._next = self._ops.get(v[len(self.op_prefix) :], self._default) self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
)
return None return None
next_ = self._next next_ = self._next
self._next = self._default if not self._sticky:
self._next = self._default
return next_ return next_
def value_transform(self, v): def value_transform(self, v):
@ -260,6 +268,7 @@ class GetMixin(PropsMixin):
- Exclusion can be specified by a leading "-" for each value (API versions <2.8) - Exclusion can be specified by a leading "-" for each value (API versions <2.8)
or by a preceding "__$not" value (operator) or by a preceding "__$not" value (operator)
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
""" """
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field) raise MakeGetAllQueryError("expected list", field)

View File

@ -1215,7 +1215,7 @@ delete {
} }
} }
archive { archive {
"2.11" { "2.12" {
description: """Archive tasks. description: """Archive tasks.
If a task is queued it will first be dequeued and then archived. If a task is queued it will first be dequeued and then archived.
""" """
@ -1629,6 +1629,10 @@ add_or_update_artifacts {
type: array type: array
items {"$ref": "#/definitions/artifact"} items {"$ref": "#/definitions/artifact"}
} }
force {
description: "If set to True then both new and running task artifacts can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {
@ -1658,6 +1662,10 @@ delete_artifacts {
type: array type: array
items {"$ref": "#/definitions/artifact_id"} items {"$ref": "#/definitions/artifact_id"}
} }
force {
description: "If set to True then both new and running task artifacts can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {
@ -1740,8 +1748,22 @@ get_hyper_params {
type: object type: object
properties { properties {
params { params {
type: object
description: "Hyper parameters (keyed by task ID)" description: "Hyper parameters (keyed by task ID)"
type: array
items {
type: object
properties {
"task": {
description: "Task ID"
type: string
}
"hyperparams": {
description: "Hyper parameters"
type: array
items {"$ref": "#/definitions/params_item"}
}
}
}
} }
} }
} }
@ -1770,6 +1792,10 @@ edit_hyper_params {
'none' (the default value) - only the specific parameters will be updated or added""" 'none' (the default value) - only the specific parameters will be updated or added"""
"$ref": "#/definitions/replace_hyperparams_enum" "$ref": "#/definitions/replace_hyperparams_enum"
} }
force {
description: "If set to True then both new and running task hyper params can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {
@ -1799,6 +1825,10 @@ delete_hyper_params {
type: array type: array
items { "$ref": "#/definitions/param_key" } items { "$ref": "#/definitions/param_key" }
} }
force {
description: "If set to True then both new and running task hyper params can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {
@ -1836,8 +1866,22 @@ get_configurations {
type: object type: object
properties { properties {
configurations { configurations {
type: object
description: "Configurations (keyed by task ID)" description: "Configurations (keyed by task ID)"
type: array
items {
type: object
properties {
"task" {
description: "Task ID"
type: string
}
"configuration" {
description: "Configuration list"
type: array
items {"$ref": "#/definitions/configuration_item"}
}
}
}
} }
} }
} }
@ -1861,8 +1905,19 @@ get_configuration_names {
type: object type: object
properties { properties {
configurations { configurations {
type: object
description: "Names of task configuration items (keyed by task ID)" description: "Names of task configuration items (keyed by task ID)"
type: object
properties {
task {
description: "Task ID"
type: string
}
names {
description: "Configuration names"
type: array
items {type: string}
}
}
} }
} }
} }
@ -1888,6 +1943,10 @@ edit_configuration {
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added" description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
type: boolean type: boolean
} }
force {
description: "If set to True then both new and running task configuration can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {
@ -1917,6 +1976,10 @@ delete_configuration {
type: array type: array
items { type: string } items { type: string }
} }
force {
description: "If set to True then both new and running task configuration can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
} }
} }
response { response {

View File

@ -74,6 +74,7 @@ from apiserver.database.model.task.task import (
Script, Script,
DEFAULT_LAST_ITERATION, DEFAULT_LAST_ITERATION,
Execution, Execution,
ArtifactModes,
) )
from apiserver.database.utils import get_fields_attr, parse_from_call from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
@ -642,6 +643,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
task_id=request.task, task_id=request.task,
hyperparams=request.hyperparams, hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams, replace_hyperparams=request.replace_hyperparams,
force=request.force,
) )
} }
@ -651,7 +653,10 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
with translate_errors_context(): with translate_errors_context():
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_params( "deleted": HyperParams.delete_params(
company_id, task_id=request.task, hyperparams=request.hyperparams company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
) )
} }
@ -699,6 +704,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
task_id=request.task, task_id=request.task,
configuration=request.configuration, configuration=request.configuration,
replace_configuration=request.replace_configuration, replace_configuration=request.replace_configuration,
force=request.force,
) )
} }
@ -710,7 +716,10 @@ def delete_configuration(
with translate_errors_context(): with translate_errors_context():
call.result.data = { call.result.data = {
"deleted": HyperParams.delete_configuration( "deleted": HyperParams.delete_configuration(
company_id, task_id=request.task, configuration=request.configuration company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
) )
} }
@ -782,15 +791,20 @@ def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
request_data_model=UpdateRequest, request_data_model=UpdateRequest,
response_data_model=DequeueResponse, response_data_model=DequeueResponse,
) )
def dequeue(call: APICall, company_id, req_model: UpdateRequest): def dequeue(call: APICall, company_id, request: UpdateRequest):
task = TaskBLL.get_task_with_access( task = TaskBLL.get_task_with_access(
req_model.task, request.task,
company_id=company_id, company_id=company_id,
only=("id", "execution", "status", "project"), only=("id", "execution", "status", "project"),
requires_write_access=True, requires_write_access=True,
) )
res = DequeueResponse( res = DequeueResponse(
**TaskBLL.dequeue_and_change_status(task, company_id, req_model) **TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=request.status_message,
status_reason=request.status_reason,
)
) )
res.dequeued = 1 res.dequeued = 1
@ -846,8 +860,8 @@ def reset(call: APICall, company_id, request: ResetRequest):
updates.update( updates.update(
set__execution__artifacts={ set__execution__artifacts={
key: artifact key: artifact
for key, artifact in task.execution.artifacts for key, artifact in task.execution.artifacts.items()
if artifact.get("mode") == "input" if artifact.mode == ArtifactModes.input
} }
) )
@ -861,6 +875,9 @@ def reset(call: APICall, company_id, request: ResetRequest):
).execute(started=None, completed=None, published=None, **updates) ).execute(started=None, completed=None, published=None, **updates)
) )
# do not return artifacts since they are not serializable
res.fields.pop("execution.artifacts", None)
for key, value in api_results.items(): for key, value in api_results.items():
setattr(res, key, value) setattr(res, key, value)
@ -892,7 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
status_reason=request.status_reason, status_reason=request.status_reason,
system_tags=sorted( system_tags=sorted(
set(task.system_tags) | {EntityVisibility.archived.value} set(task.system_tags) | {EntityVisibility.archived.value}
) ),
) )
archived += 1 archived += 1
@ -1132,7 +1149,10 @@ def add_or_update_artifacts(
with translate_errors_context(): with translate_errors_context():
call.result.data = { call.result.data = {
"updated": Artifacts.add_or_update_artifacts( "updated": Artifacts.add_or_update_artifacts(
company_id=company_id, task_id=request.task, artifacts=request.artifacts company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
) )
} }
@ -1149,6 +1169,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
company_id=company_id, company_id=company_id,
task_id=request.task, task_id=request.task,
artifact_ids=request.artifacts, artifact_ids=request.artifacts,
force=request.force,
) )
} }

View File

@ -1,6 +1,7 @@
from operator import itemgetter from operator import itemgetter
from typing import Sequence from typing import Sequence
from apiserver.apierrors.errors.bad_request import InvalidTaskStatus
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
@ -58,14 +59,15 @@ class TestTasksArtifacts(TestService):
def test_artifacts_edit_delete(self): def test_artifacts_edit_delete(self):
artifacts = [ artifacts = [
dict(key="a", type="str", uri="test1"), dict(key="a", type="str", uri="test1", mode="input"),
dict(key="b", type="int", uri="test2"), dict(key="b", type="int", uri="test2"),
dict(key="c", type="int", uri="test3"),
] ]
task = self.new_task(execution={"artifacts": artifacts}) task = self.new_task(execution={"artifacts": artifacts})
# test add_or_update # test add_or_update
edit = [ edit = [
dict(key="a", type="str", uri="hello"), dict(key="a", type="str", uri="hello", mode="input"),
dict(key="c", type="int", uri="world"), dict(key="c", type="int", uri="world"),
] ]
res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit) res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
@ -78,6 +80,23 @@ class TestTasksArtifacts(TestService):
res = self.api.tasks.get_all_ex(id=[task]).tasks[0] res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res) self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
# test edit running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
self.api.tasks.reset(task=task)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: 1], res)
def _update_source(self, source: Sequence[dict], update: Sequence[dict]): def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
dict1 = {s["key"]: s for s in source} dict1 = {s["key"]: s for s in source}
dict2 = {u["key"]: u for u in update} dict2 = {u["key"]: u for u in update}

View File

@ -118,11 +118,23 @@ class TestTasksHyperparams(TestService):
self.api.tasks.edit_hyper_params( self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")] task=task, hyperparams=[dict(section="test", name="x", value="123")]
) )
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True
)
# properties section can be edited/deleted in any task state without the flag
self.api.tasks.edit_hyper_params( self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="properties", name="x", value="123")] task=task, hyperparams=[dict(section="properties", name="x", value="123")]
) )
self.api.tasks.delete_hyper_params( self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="Properties")] task=task, hyperparams=[dict(section="properties")]
) )
@staticmethod @staticmethod
@ -204,7 +216,7 @@ class TestTasksHyperparams(TestService):
# delete # delete
new_to_delete = self._get_config_keys(new_config[1:]) new_to_delete = self._get_config_keys(new_config[1:])
res = self.api.tasks.delete_configuration( self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete task=task, configuration=new_to_delete
) )
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
@ -218,6 +230,23 @@ class TestTasksHyperparams(TestService):
finally: finally:
self.api.tasks.delete(task=new_task, force=True) self.api.tasks.delete(task=new_task, force=True)
# edit/delete of running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration(
task=task, configuration=new_config
)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True
)
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete, force=True
)
@staticmethod @staticmethod
def _get_config_keys(config: Sequence[dict]) -> List[dict]: def _get_config_keys(config: Sequence[dict]) -> List[dict]:
return [c["name"] for c in config] return [c["name"] for c in config]

View File

@ -9,7 +9,7 @@ log = config.logger(__file__)
class TestTasksEdit(TestService): class TestTasksEdit(TestService):
def setUp(self, **kwargs): def setUp(self, **kwargs):
super().setUp(version="2.9") super().setUp(version="2.12")
def new_task(self, **kwargs): def new_task(self, **kwargs):
self.update_missing( self.update_missing(