Add support for allow_public flag in get_all_ex endpoint

Add `last_changed_by` field on task updates
Fix reports support
This commit is contained in:
allegroai 2022-12-21 18:32:56 +02:00
parent c7cd949fd0
commit ae4c33fa0e
23 changed files with 256 additions and 76 deletions

View File

@ -79,3 +79,4 @@ class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -19,5 +19,7 @@ class EntitiesCountRequest(models.Base):
models = DictField()
pipelines = DictField()
datasets = DictField()
reports = DictField()
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -62,8 +62,9 @@ class ProjectsGetRequest(models.Base):
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
non_public = fields.BoolField(default=False) # legacy, use allow_public instead
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@ -318,3 +318,8 @@ class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1)
)
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)

View File

@ -58,8 +58,9 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
@ -74,7 +75,7 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, force_publish_task
model.task, company_id, user_id, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res

View File

@ -133,7 +133,7 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
@ -163,6 +163,7 @@ class QueueBLL(object):
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(

View File

@ -48,6 +48,7 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
@ -63,12 +64,13 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, update_cmds=update_cmds)
return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
@ -83,4 +85,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@ -63,6 +63,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@ -94,13 +95,17 @@ class HyperParams:
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
task,
user_id=user_id,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@classmethod
def edit_params(
cls,
company_id: str,
user_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@ -129,7 +134,10 @@ class HyperParams:
] = value
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
task,
user_id=user_id,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@classmethod
@ -201,6 +209,7 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
@ -219,11 +228,16 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, update_cmds=update_cmds)
return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
cls,
company_id: str,
user_id: str,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
@ -232,4 +246,4 @@ class HyperParams:
for name in set(configuration)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@ -33,7 +33,6 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
@ -137,6 +136,7 @@ class TaskBLL:
created=now,
last_update=now,
last_change=now,
last_changed_by=user,
**fields,
)
@ -268,6 +268,7 @@ class TaskBLL:
created=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
@ -462,7 +463,12 @@ class TaskBLL:
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
cls,
task: Task,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
):
try:
cls.dequeue(task, company_id)
@ -475,6 +481,7 @@ class TaskBLL:
new_status=task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
).execute(enqueue_status=None)
@classmethod

View File

@ -30,7 +30,11 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
task: Union[str, Task],
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> int:
"""
Deque and archive task
@ -52,7 +56,11 @@ def archive_task(
)
try:
TaskBLL.dequeue_and_change_status(
task, company_id, status_message, status_reason,
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
@ -63,11 +71,12 @@ def archive_task(
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str,
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
@ -80,11 +89,16 @@ def unarchive_task(
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
def dequeue_task(
task_id: str, company_id: str, status_message: str, status_reason: str,
task_id: str,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
@ -92,7 +106,11 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task, company_id, status_message=status_message, status_reason=status_reason,
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
@ -100,6 +118,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
queue_id: str,
status_message: str,
status_reason: str,
@ -139,6 +158,7 @@ def enqueue_task(
status_message=status_message,
allow_same_state_transition=False,
force=force,
user_id=user_id,
).execute(enqueue_status=task.status)
try:
@ -151,6 +171,7 @@ def enqueue_task(
new_status=task.status,
force=True,
status_reason="failed enqueueing",
user_id=user_id,
).execute(enqueue_status=None)
raise
@ -220,6 +241,7 @@ def delete_task(
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
@ -319,6 +341,7 @@ def reset_task(
force=force,
status_reason="reset",
status_message="reset",
user_id=user_id,
).execute(
started=None,
completed=None,
@ -334,8 +357,9 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
force: bool,
publish_model_func: Callable[[str, str], Any] = None,
publish_model_func: Callable[[str, str, str], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
@ -363,7 +387,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id)
publish_model_func(model.id, company_id, user_id)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
@ -372,6 +396,7 @@ def publish_task(
force=force,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
@ -384,7 +409,12 @@ def publish_task(
def stop_task(
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
task_id: str,
company_id: str,
user_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@ -446,4 +476,5 @@ def stop_task(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()

View File

@ -26,6 +26,7 @@ class ChangeStatusRequest(object):
force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True)
current_status_override = attr.ib(default=None)
user_id = attr.ib(type=str, default=None)
def execute(self, **kwargs):
current_status = self.current_status_override or self.task.status
@ -44,6 +45,7 @@ class ChangeStatusRequest(object):
status_changed=now,
last_update=now,
last_change=now,
last_changed_by=self.user_id,
)
if self.new_status == TaskStatus.queued:
@ -165,7 +167,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
@ -187,9 +189,9 @@ def get_task_for_update(
return task
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
now = datetime.utcnow()
last_updates = dict(last_change=now)
last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)

View File

@ -196,6 +196,7 @@ class Task(AttributedDocument):
"$name",
"$id",
"$comment",
"$report",
"$models.input.model",
"$models.output.model",
"$script.repository",
@ -206,6 +207,7 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"report": 10,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
@ -228,7 +230,7 @@ class Task(AttributedDocument):
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"),
pattern_fields=("name", "comment", "report"),
)
id = StringField(primary_key=True)
@ -242,6 +244,7 @@ class Task(AttributedDocument):
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
report = StringField()
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
@ -272,6 +275,7 @@ class Task(AttributedDocument):
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)
last_changed_by = StringField()
def get_index_company(self) -> str:
"""

View File

@ -0,0 +1,17 @@
import logging as log
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import OperationFailure
def migrate_backend(db: Database):
"""
Drop task text index so that the new one including reports field is created
"""
tasks: Collection = db["task"]
try:
tasks.drop_index("backend-db.task.main_text_index")
except OperationFailure as ex:
log.warning(f"Could not delete task text index due to: {str(ex)}")
pass

View File

@ -241,6 +241,15 @@ get_all_ex {
default: false
}
}
"999.0": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public models to be returned in the results"
type: boolean
default: true
}
}
}
}
get_all {
"2.1" {

View File

@ -176,4 +176,24 @@ get_entities_count {
}
}
}
"999.0": ${get_entities_count."2.22"} {
request.properties {
reports {
type: object
additionalProperties: true
description: Search criteria for reports
}
allow_public {
description: "Allow public entities to be counted in the results"
type: boolean
default: true
}
}
response.properties {
reports {
type: integer
description: The number of reports matching the criteria
}
}
}
}

View File

@ -611,6 +611,15 @@ get_all_ex {
default: false
}
}
"999.0": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public projects to be returned in the results"
type: boolean
default: true
}
}
}
}
update {
"2.1" {

View File

@ -181,6 +181,15 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"999.0": ${get_all_ex."2.15"} {
request.properties {
allow_public {
description: "Allow public tasks to be returned in the results"
type: boolean
default: true
}
}
}
}
get_all {
"2.1" {

View File

@ -116,7 +116,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
allow_public=request.allow_public,
ret_params=ret_params,
)
conform_output_tags(call, models)
@ -482,6 +482,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@ -500,6 +501,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),

View File

@ -10,7 +10,7 @@ from apiserver.bll.project import ProjectBLL
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
@ -59,6 +59,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
"models": Model,
"pipelines": Project,
"datasets": Project,
"reports": Task,
}
ret = {}
for field, entity_cls in entity_classes.items():
@ -66,6 +67,10 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
if data is None:
continue
if field == "reports":
data["type"] = TaskType.report
data["include_subprojects"] = True
if request.active_users:
if entity_cls is Project:
requested_ids = data.get("id")
@ -75,7 +80,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
company=company,
users=request.active_users,
project_ids=requested_ids,
allow_public=True,
allow_public=request.allow_public,
)
if not ids:
ret[field] = 0
@ -85,11 +90,18 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
data["user"] = request.active_users
query = Q()
if entity_cls in (Project, Task) and not request.search_hidden:
if (
entity_cls in (Project, Task)
and field != "reports"
and not request.search_hidden
):
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count(
company=company, query_dict=data, query=query, allow_public=True,
company=company,
query_dict=data,
query=query,
allow_public=request.allow_public,
)
call.result.data = ret

View File

@ -100,7 +100,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data
conform_tag_fields(call, data)
allow_public = not request.non_public
allow_public = (
data["allow_public"]
if "allow_public" in data
else not data["non_public"]
if "non_public" in data
else request.allow_public
)
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]

View File

@ -142,7 +142,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
queue_bll.delete(
company_id=company_id, queue_id=req_model.queue, force=req_model.force
company_id=company_id,
user_id=call.identity.user,
queue_id=req_model.queue,
force=req_model.force,
)
call.result.data = {"deleted": 1}

View File

@ -51,9 +51,7 @@ update_fields = {
}
def _assert_report(
company_id, task_id, only_fields=None, requires_write_access=True
):
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
if only_fields and "type" not in only_fields:
only_fields += ("type",)
@ -72,9 +70,7 @@ def _assert_report(
@endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report(
task_id=request.task,
company_id=company_id,
only_fields=("status",),
task_id=request.task, company_id=company_id, only_fields=("status",),
)
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
@ -196,7 +192,7 @@ def _get_task_metrics_from_request(
return task_metrics
@endpoint("reports.get_task_data", required_fields=[])
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
call_data = escape_execution_parameters(call)
process_include_subprojects(call_data)
@ -212,16 +208,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
unprepare_from_saved(call, tasks)
res = {"tasks": tasks, **ret_params}
if not (
request.debug_images
or request.plots
or request.scalar_metrics_iter_histogram
request.debug_images or request.plots or request.scalar_metrics_iter_histogram
):
return res
task_ids = [task["id"] for task in tasks]
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids
)
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events(
company_id=company,
@ -264,9 +256,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
)
task = _assert_report(
company_id=company_id,
task_id=request.task,
only_fields=("project",),
company_id=company_id, task_id=request.task, only_fields=("project",),
)
user_id = call.identity.user
project_name = request.project_name
@ -297,12 +287,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
"reports.publish", response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report(
company_id=company_id, task_id=request.task
)
task = _assert_report(company_id=company_id, task_id=request.task)
updates = ChangeStatusRequest(
task=task,
company=company_id,
new_status=TaskStatus.published,
force=True,
status_reason="",
@ -315,9 +302,7 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(
company_id=company_id, task_id=request.task
)
task = _assert_report(company_id=company_id, task_id=request.task)
archived = task.update(
status_message=request.message,
status_reason="",
@ -331,9 +316,7 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(
company_id=company_id, task_id=request.task
)
task = _assert_report(company_id=company_id, task_id=request.task)
unarchived = task.update(
status_message=request.message,
status_reason="",
@ -359,9 +342,7 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report(
company_id=company_id,
task_id=request.task,
only_fields=("project",),
company_id=company_id, task_id=request.task, only_fields=("project",),
)
if (
task.status != TaskStatus.created

View File

@ -64,6 +64,7 @@ from apiserver.apimodels.tasks import (
ResetBatchItem,
CompletedRequest,
CompletedResponse,
GetAllReq,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@ -136,7 +137,7 @@ project_bll = ProjectBLL()
def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **set_fields
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
) -> dict:
fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access(
@ -171,6 +172,7 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute(**fields_resolver.get_fields(task))
@ -214,8 +216,8 @@ def _hidden_query(data: dict) -> Q:
return Q(system_tags__ne=EntityVisibility.hidden.value)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
@endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
@ -226,7 +228,7 @@ def get_all_ex(call: APICall, company_id, _):
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
@ -291,6 +293,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@ -308,6 +311,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@ -329,7 +333,8 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(
req_model,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@ -345,7 +350,8 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
res = StartedResponse(
**set_task_status_from_call(
req_model,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
)
@ -359,7 +365,12 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
)
def failed(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed)
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.failed,
)
)
@ -368,7 +379,11 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
)
def close(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed)
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.closed)
)
@ -580,7 +595,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
company_id=company_id,
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(last_change=datetime.utcnow()),
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
),
)
if updated_count:
new_project = updated_fields.get("project", task.project)
@ -613,7 +630,11 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
raise errors.bad_request.MissingTaskFields(
"Task has no script field", task=task.id
)
res = update_task(task, update_cmds=dict(script__requirements=requirements))
res = update_task(
task,
user_id=call.identity.user,
update_cmds=dict(script__requirements=requirements),
)
call.result.data_model = UpdateResponse(updated=res)
if res:
call.result.data_model.fields = {"script.requirements": requirements}
@ -648,7 +669,9 @@ def update_batch(call: APICall, company_id, _):
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
partial_update_dict.update(last_change=now)
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
@ -725,7 +748,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
}
if fixed_fields:
now = datetime.utcnow()
last_change = dict(last_change=now)
last_change = dict(last_change=now, last_changed_by=call.identity.user)
if not set(fields).issubset(Task.user_set_allowed()):
last_change.update(last_update=now)
fields.update(**last_change)
@ -762,6 +785,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@ -775,6 +799,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@ -819,6 +844,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@ -834,6 +860,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@ -850,6 +877,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@ -874,6 +902,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@ -908,6 +937,7 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
)
@ -924,6 +954,7 @@ def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@ -1019,6 +1050,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@ -1037,6 +1069,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@ -1058,6 +1091,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@ -1136,6 +1170,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@ -1154,6 +1189,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@ -1180,7 +1216,8 @@ def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse(
**set_task_status_from_call(
request,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@ -1190,6 +1227,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@ -1221,6 +1259,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@ -1237,6 +1276,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@ -1304,7 +1344,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
delete_names = {
@ -1317,5 +1357,5 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
if names
}
updated = task.update(last_change=datetime.utcnow(), **commands,)
updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,)
return {"updated": updated}