Support publish option in tasks.completed

This commit is contained in:
allegroai 2022-07-08 17:40:43 +03:00
parent adc1825843
commit 05357fe25e
4 changed files with 101 additions and 4 deletions

View File

@ -109,6 +109,14 @@ class SetRequirementsRequest(TaskRequest):
requirements = DictField(required=True)
class CompletedRequest(UpdateRequest):
publish = BoolField(default=False)
class CompletedResponse(UpdateResponse):
published = IntField(default=0)
class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)

View File

@ -2031,6 +2031,18 @@ completed {
} ${_references.status_change_request}
response: ${_definitions.update_response}
}
"999.0": ${completed."2.2"} {
request.properties.publish {
type: boolean
default: false
description: If set and the task is completed successfully then it is published
}
response.properties.published {
description: "Number of tasks published (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
ping {

View File

@ -62,6 +62,8 @@ from apiserver.apimodels.tasks import (
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
CompletedRequest,
CompletedResponse,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@ -119,6 +121,7 @@ from apiserver.services.utils import (
unescape_dict_field,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields())
@ -1161,11 +1164,11 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
@endpoint(
"tasks.completed",
min_version="2.2",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
request_data_model=CompletedRequest,
response_data_model=CompletedResponse,
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse(
**set_task_status_from_call(
request,
company_id,
@ -1174,6 +1177,22 @@ def completed(call: APICall, company_id, request: PublishRequest):
)
)
if res.updated and request.publish:
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
status_message=request.status_message,
)
res.published = publish_res.get("updated")
new_status = nested_get(publish_res, ("fields", "status"))
if new_status:
res.fields["status"] = new_status
call.result.data_model = res
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):

View File

@ -0,0 +1,58 @@
import time
from apiserver.tests.automated import TestService
class TestTasksRunning(TestService):
STATUS_STOPPED = "stopped"
STATUS_COMPLETED = "completed"
STATUS_PUBLISHED = "published"
STATUS_RUNNING = "in_progress"
def test_stop_regular_task(self):
task_id = self._create_running_task()
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_STOPPED
def test_stop_regular_task_with_active_worker(self):
task_id = self._create_running_task()
worker_id = "worker1"
self.api.workers.register(worker=worker_id)
self.api.workers.status_report(
worker=worker_id, task=task_id, timestamp=int(time.time())
)
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_RUNNING
assert data.status_message == "stopping"
def test_stop_development_task(self):
task_id = self._create_running_task(is_development=True)
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_STOPPED
def test_completed_task(self):
task_id = self._create_running_task()
res = self.api.tasks.completed(task=task_id)
assert res.fields.status == self.STATUS_COMPLETED
assert res.updated == 1
assert res.published == 0
res = self.api.tasks.completed(task=task_id, publish=True)
assert res.fields.status == self.STATUS_PUBLISHED
assert res.updated == 1
assert res.published == 1
def _create_running_task(self, is_development=False):
task_input = dict(
name="task-1",
type="testing",
input=dict(mapping={}, view=dict()),
)
if is_development:
task_input["system_tags"] = ["development"]
task_id = self.create_temp("tasks", **task_input)
self.api.tasks.started(task=task_id)
return task_id