API version bump

Update internal tests
Allow edit/delete task artifacts/hyperparams/configs using force flag
Improve lists query support for get_all calls
This commit is contained in:
allegroai 2021-01-05 17:57:58 +02:00
parent 563c451ac9
commit 0303c3525f
11 changed files with 207 additions and 33 deletions

View File

@ -1 +1 @@
__version__ = "2.11.0"
__version__ = "2.12.0"

View File

@ -119,6 +119,7 @@ class CloneRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base):
@ -130,6 +131,7 @@ class ArtifactId(models.Base):
class DeleteArtifactsRequest(TaskRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest):
@ -166,6 +168,7 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
force = BoolField(default=False)
class HyperParamKey(models.Base):
@ -177,6 +180,7 @@ class DeleteHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest):
@ -199,10 +203,12 @@ class EditConfigurationRequest(TaskRequest):
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest):

View File

@ -48,11 +48,17 @@ def artifacts_unprepare_from_saved(fields):
class Artifacts:
@classmethod
def add_or_update_artifacts(
cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact],
cls,
company_id: str,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True
company_id=company_id,
task_id=task_id,
force=force,
)
artifacts = {
@ -68,11 +74,17 @@ class Artifacts:
@classmethod
def delete_artifacts(
cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId]
cls,
company_id: str,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True
company_id=company_id,
task_id=task_id,
force=force,
)
artifact_ids = [

View File

@ -63,7 +63,11 @@ class HyperParams:
@classmethod
def delete_params(
cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
@ -71,6 +75,7 @@ class HyperParams:
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
with_param, without_param = iterutils.partition(
@ -100,6 +105,7 @@ class HyperParams:
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
force: bool,
) -> int:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
@ -107,6 +113,7 @@ class HyperParams:
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
update_cmds = dict()
@ -198,9 +205,12 @@ class HyperParams:
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
update_cmds = dict()
configuration = {
@ -217,10 +227,12 @@ class HyperParams:
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str]
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
) -> int:
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1

View File

@ -174,7 +174,7 @@ def split_by(
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
@ -186,7 +186,10 @@ def get_task_for_update(
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
allowed_statuses = (
[TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
)
if task.status not in allowed_statuses:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)

View File

@ -104,8 +104,13 @@ class GetMixin(PropsMixin):
legacy_exclude_prefix = "-"
_default = "in"
_ops = {"not": "nin"}
_ops = {
"not": ("nin", False),
"all": ("all", True),
"and": ("all", True),
}
_next = _default
_sticky = False
def __init__(self, legacy=False):
self._legacy = legacy
@ -116,13 +121,16 @@ class GetMixin(PropsMixin):
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"]
return self._ops["not"][0]
elif v.startswith(self.op_prefix):
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
)
return None
next_ = self._next
self._next = self._default
if not self._sticky:
self._next = self._default
return next_
def value_transform(self, v):
@ -260,6 +268,7 @@ class GetMixin(PropsMixin):
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
or by a preceding "__$not" value (operator)
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)

View File

@ -1215,7 +1215,7 @@ delete {
}
}
archive {
"2.11" {
"2.12" {
description: """Archive tasks.
If a task is queued it will first be dequeued and then archived.
"""
@ -1629,6 +1629,10 @@ add_or_update_artifacts {
type: array
items {"$ref": "#/definitions/artifact"}
}
force {
description: "If set to True then both new and running task artifacts can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {
@ -1658,6 +1662,10 @@ delete_artifacts {
type: array
items {"$ref": "#/definitions/artifact_id"}
}
force {
description: "If set to True then both new and running task artifacts can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {
@ -1740,8 +1748,22 @@ get_hyper_params {
type: object
properties {
params {
type: object
description: "Hyper parameters (keyed by task ID)"
type: array
items {
type: object
properties {
"task": {
description: "Task ID"
type: string
}
"hyperparams": {
description: "Hyper parameters"
type: array
items {"$ref": "#/definitions/params_item"}
}
}
}
}
}
}
@ -1770,6 +1792,10 @@ edit_hyper_params {
'none' (the default value) - only the specific parameters will be updated or added"""
"$ref": "#/definitions/replace_hyperparams_enum"
}
force {
description: "If set to True then both new and running task hyper params can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {
@ -1799,6 +1825,10 @@ delete_hyper_params {
type: array
items { "$ref": "#/definitions/param_key" }
}
force {
description: "If set to True then both new and running task hyper params can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {
@ -1836,8 +1866,22 @@ get_configurations {
type: object
properties {
configurations {
type: object
description: "Configurations (keyed by task ID)"
type: array
items {
type: object
properties {
"task" {
description: "Task ID"
type: string
}
"configuration" {
description: "Configuration list"
type: array
items {"$ref": "#/definitions/configuration_item"}
}
}
}
}
}
}
@ -1861,8 +1905,19 @@ get_configuration_names {
type: object
properties {
configurations {
type: object
description: "Names of task configuration items (keyed by task ID)"
type: object
properties {
task {
description: "Task ID"
type: string
}
names {
description: "Configuration names"
type: array
items {type: string}
}
}
}
}
}
@ -1888,6 +1943,10 @@ edit_configuration {
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
type: boolean
}
force {
description: "If set to True then both new and running task configuration can be edited. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {
@ -1917,6 +1976,10 @@ delete_configuration {
type: array
items { type: string }
}
force {
description: "If set to True then both new and running task configuration can be deleted. Otherwise only the new task ones. Default is False"
type: boolean
}
}
}
response {

View File

@ -74,6 +74,7 @@ from apiserver.database.model.task.task import (
Script,
DEFAULT_LAST_ITERATION,
Execution,
ArtifactModes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.service_repo import APICall, endpoint
@ -642,6 +643,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
@ -651,7 +653,10 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id, task_id=request.task, hyperparams=request.hyperparams
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
@ -699,6 +704,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
@ -710,7 +716,10 @@ def delete_configuration(
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id, task_id=request.task, configuration=request.configuration
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
@ -782,15 +791,20 @@ def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
request_data_model=UpdateRequest,
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, req_model: UpdateRequest):
def dequeue(call: APICall, company_id, request: UpdateRequest):
task = TaskBLL.get_task_with_access(
req_model.task,
request.task,
company_id=company_id,
only=("id", "execution", "status", "project"),
requires_write_access=True,
)
res = DequeueResponse(
**TaskBLL.dequeue_and_change_status(task, company_id, req_model)
**TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=request.status_message,
status_reason=request.status_reason,
)
)
res.dequeued = 1
@ -846,8 +860,8 @@ def reset(call: APICall, company_id, request: ResetRequest):
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts
if artifact.get("mode") == "input"
for key, artifact in task.execution.artifacts.items()
if artifact.mode == ArtifactModes.input
}
)
@ -861,6 +875,9 @@ def reset(call: APICall, company_id, request: ResetRequest):
).execute(started=None, completed=None, published=None, **updates)
)
# do not return artifacts since they are not serializable
res.fields.pop("execution.artifacts", None)
for key, value in api_results.items():
setattr(res, key, value)
@ -892,7 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
status_reason=request.status_reason,
system_tags=sorted(
set(task.system_tags) | {EntityVisibility.archived.value}
)
),
)
archived += 1
@ -1132,7 +1149,10 @@ def add_or_update_artifacts(
with translate_errors_context():
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id, task_id=request.task, artifacts=request.artifacts
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
@ -1149,6 +1169,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}

View File

@ -1,6 +1,7 @@
from operator import itemgetter
from typing import Sequence
from apiserver.apierrors.errors.bad_request import InvalidTaskStatus
from apiserver.tests.automated import TestService
@ -58,14 +59,15 @@ class TestTasksArtifacts(TestService):
def test_artifacts_edit_delete(self):
artifacts = [
dict(key="a", type="str", uri="test1"),
dict(key="a", type="str", uri="test1", mode="input"),
dict(key="b", type="int", uri="test2"),
dict(key="c", type="int", uri="test3"),
]
task = self.new_task(execution={"artifacts": artifacts})
# test add_or_update
edit = [
dict(key="a", type="str", uri="hello"),
dict(key="a", type="str", uri="hello", mode="input"),
dict(key="c", type="int", uri="world"),
]
res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
@ -78,6 +80,23 @@ class TestTasksArtifacts(TestService):
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
# test edit running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
self.api.tasks.reset(task=task)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: 1], res)
def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
dict1 = {s["key"]: s for s in source}
dict2 = {u["key"]: u for u in update}

View File

@ -118,11 +118,23 @@ class TestTasksHyperparams(TestService):
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")]
)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True
)
# properties section can be edited/deleted in any task state without the flag
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="properties", name="x", value="123")]
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="Properties")]
task=task, hyperparams=[dict(section="properties")]
)
@staticmethod
@ -204,7 +216,7 @@ class TestTasksHyperparams(TestService):
# delete
new_to_delete = self._get_config_keys(new_config[1:])
res = self.api.tasks.delete_configuration(
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
@ -218,6 +230,23 @@ class TestTasksHyperparams(TestService):
finally:
self.api.tasks.delete(task=new_task, force=True)
# edit/delete of running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration(
task=task, configuration=new_config
)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True
)
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete, force=True
)
@staticmethod
def _get_config_keys(config: Sequence[dict]) -> List[dict]:
return [c["name"] for c in config]

View File

@ -9,7 +9,7 @@ log = config.logger(__file__)
class TestTasksEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")
super().setUp(version="2.12")
def new_task(self, **kwargs):
self.update_missing(