Add keep alive api

This commit is contained in:
allegroai 2019-07-09 00:02:05 +03:00
parent 76418eec1b
commit 61fb6553e6

View File

@ -23,23 +23,14 @@ from apimodels.tasks import (
SetRequirementsRequest,
TaskRequest,
DeleteRequest,
PingRequest,
)
from bll.event import EventBLL
from bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
split_by,
)
from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.task.output import Output
from database.model.task.task import (
Task,
TaskStatus,
Script,
DEFAULT_LAST_ITERATION,
)
from database.model.task.task import Task, TaskStatus, Script, DEFAULT_LAST_ITERATION
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from timing_context import TimingContext
@ -48,14 +39,7 @@ from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=(
"id",
"user",
"tags",
"type",
"status",
"project",
),
list_fields=("id", "user", "tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
@ -65,6 +49,9 @@ task_bll = TaskBLL()
event_bll = EventBLL()
TaskBLL.start_non_responsive_tasks_watchdog()
def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **kwargs
) -> dict:
@ -154,7 +141,10 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
def stopped(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(
req_model, company_id, new_status=TaskStatus.stopped, completed=datetime.utcnow()
req_model,
company_id,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
)
@ -167,7 +157,10 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
def started(call: APICall, company_id, req_model: UpdateRequest):
res = StartedResponse(
**set_task_status_from_call(
req_model, company_id, new_status=TaskStatus.in_progress, started=datetime.utcnow()
req_model,
company_id,
new_status=TaskStatus.in_progress,
started=datetime.utcnow(),
)
)
res.started = res.updated
@ -226,11 +219,6 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
try:
dpath.delete(fields, "script/requirements")
except dpath.exceptions.PathNotFound:
pass
# Make sure there are no duplicate tags
tags = fields.get("tags")
if tags:
@ -471,7 +459,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
api_results.update(attr.asdict(cleaned_up))
updates.update(
unset__script__requirements=1,
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
unset__output__result=1,
@ -680,3 +667,25 @@ def publish(call: APICall, company_id, req_model: PublishRequest):
status_message=req_model.status_message,
)
)
@endpoint(
"tasks.completed", min_version="2.2", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(
request,
company_id,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
)
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
)