diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py index 6c5a113..1844e8a 100644 --- a/apiserver/apimodels/models.py +++ b/apiserver/apimodels/models.py @@ -79,3 +79,4 @@ class AddOrUpdateMetadataRequest(AddOrUpdateMetadata): class ModelsGetRequest(models.Base): include_stats = fields.BoolField(default=False) + allow_public = fields.BoolField(default=True) diff --git a/apiserver/apimodels/organization.py b/apiserver/apimodels/organization.py index f9b72d0..46bcd1f 100644 --- a/apiserver/apimodels/organization.py +++ b/apiserver/apimodels/organization.py @@ -19,5 +19,7 @@ class EntitiesCountRequest(models.Base): models = DictField() pipelines = DictField() datasets = DictField() + reports = DictField() active_users = fields.ListField(str) search_hidden = fields.BoolField(default=False) + allow_public = fields.BoolField(default=True) diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 58f2e46..755829a 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -62,8 +62,9 @@ class ProjectsGetRequest(models.Base): include_stats_filter = DictField() stats_with_children = fields.BoolField(default=True) stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active) - non_public = fields.BoolField(default=False) + non_public = fields.BoolField(default=False) # legacy, use allow_public instead active_users = fields.ListField(str) check_own_contents = fields.BoolField(default=False) shallow_search = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False) + allow_public = fields.BoolField(default=True) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index cf372c4..bd64d1f 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -318,3 +318,8 @@ class DeleteModelsRequest(TaskRequest): models: Sequence[ModelItemKey] = ListField( [ModelItemKey], validators=Length(minimum_value=1) ) + + +class GetAllReq(models.Base): + allow_public = BoolField(default=True) + search_hidden = BoolField(default=False) diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py index f5e234a..85f129a 100644 --- a/apiserver/bll/model/__init__.py +++ b/apiserver/bll/model/__init__.py @@ -58,8 +58,9 @@ class ModelBLL: cls, model_id: str, company_id: str, + user_id: str, force_publish_task: bool = False, - publish_task_func: Callable[[str, str, bool], dict] = None, + publish_task_func: Callable[[str, str, str, bool], dict] = None, ) -> Tuple[int, ModelTaskPublishResponse]: model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id) if model.ready: @@ -74,7 +75,7 @@ class ModelBLL: ) if task and task.status != TaskStatus.published: task_publish_res = publish_task_func( - model.task, company_id, force_publish_task + model.task, company_id, user_id, force_publish_task ) published_task = ModelTaskPublishResponse( id=model.task, data=task_publish_res diff --git a/apiserver/bll/queue/queue_bll.py b/apiserver/bll/queue/queue_bll.py index c385936..88bff5e 100644 --- a/apiserver/bll/queue/queue_bll.py +++ b/apiserver/bll/queue/queue_bll.py @@ -133,7 +133,7 @@ class QueueBLL(object): self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",)) return Queue.safe_update(company_id, queue_id, update_fields) - def delete(self, company_id: str, queue_id: str, force: bool) -> None: + def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None: """ Delete the queue :raise errors.bad_request.InvalidQueueId: if the queue is not found @@ -163,6 +163,7 @@ class QueueBLL(object): new_status=task.enqueue_status or TaskStatus.created, status_reason="Queue was deleted", status_message="", + user_id=user_id, ).execute(enqueue_status=None) except Exception as ex: log.exception( diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index 0305f5e..7ff19bf 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -48,6 +48,7 @@ class Artifacts: def add_or_update_artifacts( cls, company_id: str, + user_id: str, task_id: str, artifacts: Sequence[ApiArtifact], force: bool, @@ -63,12 +64,13 @@ class Artifacts: f"set__execution__artifacts__{mongoengine_safe(name)}": value for name, value in artifacts.items() } - return update_task(task, update_cmds=update_cmds) + return update_task(task, user_id=user_id, update_cmds=update_cmds) @classmethod def delete_artifacts( cls, company_id: str, + user_id: str, task_id: str, artifact_ids: Sequence[ArtifactId], force: bool, @@ -83,4 +85,4 @@ class Artifacts: f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) } - return update_task(task, update_cmds=delete_cmds) + return update_task(task, user_id=user_id, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index f8d66c9..4159b87 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -63,6 +63,7 @@ class HyperParams: def delete_params( cls, company_id: str, + user_id: str, task_id: str, hyperparams: Sequence[HyperParamKey], force: bool, @@ -94,13 +95,17 @@ class HyperParams: delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1 return update_task( - task, update_cmds=delete_cmds, set_last_update=not properties_only + task, + user_id=user_id, + update_cmds=delete_cmds, + set_last_update=not properties_only, ) @classmethod def edit_params( cls, company_id: str, + user_id: str, task_id: str, hyperparams: Sequence[HyperParamItem], replace_hyperparams: str, @@ -129,7 +134,10 @@ class HyperParams: ] = value return update_task( - task, update_cmds=update_cmds, set_last_update=not properties_only + task, + user_id=user_id, + update_cmds=update_cmds, + set_last_update=not properties_only, ) @classmethod @@ -201,6 +209,7 @@ class HyperParams: def edit_configuration( cls, company_id: str, + user_id: str, task_id: str, configuration: Sequence[Configuration], replace_configuration: bool, @@ -219,11 +228,16 @@ class HyperParams: for name, value in configuration.items(): update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value - return update_task(task, update_cmds=update_cmds) + return update_task(task, user_id=user_id, update_cmds=update_cmds) @classmethod def delete_configuration( - cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool + cls, + company_id: str, + user_id: str, + task_id: str, + configuration: Sequence[str], + force: bool, ) -> int: task = get_task_for_update(company_id=company_id, task_id=task_id, force=force) @@ -232,4 +246,4 @@ class HyperParams: for name in set(configuration) } - return update_task(task, update_cmds=delete_cmds) + return update_task(task, user_id=user_id, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index ce2b806..939f62f 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -33,7 +33,6 @@ from apiserver.database.model import EntityVisibility from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.es_factory import es_factory from apiserver.redis_manager import redman -from apiserver.service_repo import APICall from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict from .artifacts import artifacts_prepare_for_save from .param_utils import params_prepare_for_save @@ -137,6 +136,7 @@ class TaskBLL: created=now, last_update=now, last_change=now, + last_changed_by=user, **fields, ) @@ -268,6 +268,7 @@ class TaskBLL: created=now, last_update=now, last_change=now, + last_changed_by=user_id, name=name or task.name, comment=comment or task.comment, parent=parent or parent_task, @@ -462,7 +463,12 @@ class TaskBLL: @classmethod def dequeue_and_change_status( - cls, task: Task, company_id: str, status_message: str, status_reason: str, + cls, + task: Task, + company_id: str, + user_id: str, + status_message: str, + status_reason: str, ): try: cls.dequeue(task, company_id) @@ -475,6 +481,7 @@ class TaskBLL: new_status=task.enqueue_status or TaskStatus.created, status_reason=status_reason, status_message=status_message, + user_id=user_id, ).execute(enqueue_status=None) @classmethod diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index 5fc3a1b..fd92a57 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -30,7 +30,11 @@ queue_bll = QueueBLL() def archive_task( - task: Union[str, Task], company_id: str, status_message: str, status_reason: str, + task: Union[str, Task], + company_id: str, + user_id: str, + status_message: str, + status_reason: str, ) -> int: """ Deque and archive task @@ -52,7 +56,11 @@ def archive_task( ) try: TaskBLL.dequeue_and_change_status( - task, company_id, status_message, status_reason, + task, + company_id=company_id, + user_id=user_id, + status_message=status_message, + status_reason=status_reason, ) except APIError: # dequeue may fail if the task was not enqueued @@ -63,11 +71,12 @@ def archive_task( status_reason=status_reason, add_to_set__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), + last_changed_by=user_id, ) def unarchive_task( - task: str, company_id: str, status_message: str, status_reason: str, + task: str, company_id: str, user_id: str, status_message: str, status_reason: str, ) -> int: """ Unarchive task. Return 1 if successful @@ -80,11 +89,16 @@ def unarchive_task( status_reason=status_reason, pull__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), + last_changed_by=user_id, ) def dequeue_task( - task_id: str, company_id: str, status_message: str, status_reason: str, + task_id: str, + company_id: str, + user_id: str, + status_message: str, + status_reason: str, ) -> Tuple[int, dict]: query = dict(id=task_id, company=company_id) task = Task.get_for_writing(**query) @@ -92,7 +106,11 @@ def dequeue_task( raise errors.bad_request.InvalidTaskId(**query) res = TaskBLL.dequeue_and_change_status( - task, company_id, status_message=status_message, status_reason=status_reason, + task, + company_id=company_id, + user_id=user_id, + status_message=status_message, + status_reason=status_reason, ) return 1, res @@ -100,6 +118,7 @@ def dequeue_task( def enqueue_task( task_id: str, company_id: str, + user_id: str, queue_id: str, status_message: str, status_reason: str, @@ -139,6 +158,7 @@ def enqueue_task( status_message=status_message, allow_same_state_transition=False, force=force, + user_id=user_id, ).execute(enqueue_status=task.status) try: @@ -151,6 +171,7 @@ def enqueue_task( new_status=task.status, force=True, status_reason="failed enqueueing", + user_id=user_id, ).execute(enqueue_status=None) raise @@ -220,6 +241,7 @@ def delete_task( TaskBLL.dequeue_and_change_status( task, company_id=company_id, + user_id=user_id, status_message=status_message, status_reason=status_reason, ) @@ -319,6 +341,7 @@ def reset_task( force=force, status_reason="reset", status_message="reset", + user_id=user_id, ).execute( started=None, completed=None, @@ -334,8 +357,9 @@ def reset_task( def publish_task( task_id: str, company_id: str, + user_id: str, force: bool, - publish_model_func: Callable[[str, str], Any] = None, + publish_model_func: Callable[[str, str, str], Any] = None, status_message: str = "", status_reason: str = "", ) -> dict: @@ -363,7 +387,7 @@ def publish_task( .first() ) if model and not model.ready: - publish_model_func(model.id, company_id) + publish_model_func(model.id, company_id, user_id) # set task status to published, and update (or set) it's new output (view and models) return ChangeStatusRequest( @@ -372,6 +396,7 @@ def publish_task( force=force, status_reason=status_reason, status_message=status_message, + user_id=user_id, ).execute(published=datetime.utcnow(), output=output) except Exception as ex: @@ -384,7 +409,12 @@ def publish_task( def stop_task( - task_id: str, company_id: str, user_name: str, status_reason: str, force: bool, + task_id: str, + company_id: str, + user_id: str, + user_name: str, + status_reason: str, + force: bool, ) -> dict: """ Stop a running task. Requires task status 'in_progress' and @@ -446,4 +476,5 @@ def stop_task( status_reason=status_reason, status_message=status_message, force=force, + user_id=user_id, ).execute() diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 070309d..418960c 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -26,6 +26,7 @@ class ChangeStatusRequest(object): force = attr.ib(type=bool, default=False) allow_same_state_transition = attr.ib(type=bool, default=True) current_status_override = attr.ib(default=None) + user_id = attr.ib(type=str, default=None) def execute(self, **kwargs): current_status = self.current_status_override or self.task.status @@ -44,6 +45,7 @@ class ChangeStatusRequest(object): status_changed=now, last_update=now, last_change=now, + last_changed_by=self.user_id, ) if self.new_status == TaskStatus.queued: @@ -165,7 +167,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]): def get_task_for_update( - company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False + company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False ) -> Task: """ Loads only task id and return the task only if it is updatable (status == 'created') @@ -187,9 +189,9 @@ def get_task_for_update( return task -def update_task(task: Task, update_cmds: dict, set_last_update: bool = True): +def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True): now = datetime.utcnow() - last_updates = dict(last_change=now) + last_updates = dict(last_change=now, last_changed_by=user_id) if set_last_update: last_updates.update(last_update=now) return task.update(**update_cmds, **last_updates) diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 99911f5..b199120 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -196,6 +196,7 @@ class Task(AttributedDocument): "$name", "$id", "$comment", + "$report", "$models.input.model", "$models.output.model", "$script.repository", @@ -206,6 +207,7 @@ class Task(AttributedDocument): "name": 10, "id": 10, "comment": 10, + "report": 10, "models.output.model": 2, "models.input.model": 2, "script.repository": 1, @@ -228,7 +230,7 @@ class Task(AttributedDocument): ), range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"), datetime_fields=("status_changed", "last_update"), - pattern_fields=("name", "comment"), + pattern_fields=("name", "comment", "report"), ) id = StringField(primary_key=True) @@ -242,6 +244,7 @@ class Task(AttributedDocument): status_message = StringField(user_set_allowed=True) status_changed = DateTimeField() comment = StringField(user_set_allowed=True) + report = StringField() created = DateTimeField(required=True, user_set_allowed=True) started = DateTimeField() completed = DateTimeField() @@ -272,6 +275,7 @@ class Task(AttributedDocument): enqueue_status = StringField( choices=get_options(TaskStatus), exclude_by_default=True ) + last_changed_by = StringField() def get_index_company(self) -> str: """ diff --git a/apiserver/mongo/migrations/1_9_0.py b/apiserver/mongo/migrations/1_9_0.py new file mode 100644 index 0000000..0fdcf0a --- /dev/null +++ b/apiserver/mongo/migrations/1_9_0.py @@ -0,0 +1,17 @@ +import logging as log + +from pymongo.collection import Collection +from pymongo.database import Database +from pymongo.errors import OperationFailure + + +def migrate_backend(db: Database): + """ + Drop task text index so that the new one including reports field is created + """ + tasks: Collection = db["task"] + try: + tasks.drop_index("backend-db.task.main_text_index") + except OperationFailure as ex: + log.warning(f"Could not delete task text index due to: {str(ex)}") + pass diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index b1853e2..6aae3ab 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -241,6 +241,15 @@ get_all_ex { default: false } } + "999.0": ${get_all_ex."2.20"} { + request.properties { + allow_public { + description: "Allow public models to be returned in the results" + type: boolean + default: true + } + } + } } get_all { "2.1" { diff --git a/apiserver/schema/services/organization.conf b/apiserver/schema/services/organization.conf index b0c731a..9e7df89 100644 --- a/apiserver/schema/services/organization.conf +++ b/apiserver/schema/services/organization.conf @@ -176,4 +176,24 @@ get_entities_count { } } } + "999.0": ${get_entities_count."2.22"} { + request.properties { + reports { + type: object + additionalProperties: true + description: Search criteria for reports + } + allow_public { + description: "Allow public entities to be counted in the results" + type: boolean + default: true + } + } + response.properties { + reports { + type: integer + description: The number of reports matching the criteria + } + } + } } diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 6b87936..4c4c6f0 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -611,6 +611,15 @@ get_all_ex { default: false } } + "999.0": ${get_all_ex."2.20"} { + request.properties { + allow_public { + description: "Allow public projects to be returned in the results" + type: boolean + default: true + } + } + } } update { "2.1" { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 729548f..ce5f9dd 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -181,6 +181,15 @@ get_all_ex { description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data" } } + "999.0": ${get_all_ex."2.15"} { + request.properties { + allow_public { + description: "Allow public tasks to be returned in the results" + type: boolean + default: true + } + } + } } get_all { "2.1" { diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 7115329..d91ba4c 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -116,7 +116,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): models = Model.get_many_with_join( company=company_id, query_dict=call.data, - allow_public=True, + allow_public=request.allow_public, ret_params=ret_params, ) conform_output_tags(call, models) @@ -482,6 +482,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest): updated, published_task = ModelBLL.publish_model( model_id=request.model, company_id=company_id, + user_id=call.identity.user, force_publish_task=request.force_publish_task, publish_task_func=publish_task if request.publish_task else None, ) @@ -500,6 +501,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest): func=partial( ModelBLL.publish_model, company_id=company_id, + user_id=call.identity.user, force_publish_task=request.force_publish_task, publish_task_func=publish_task if request.publish_task else None, ), diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py index b068e3e..30395c8 100644 --- a/apiserver/services/organization.py +++ b/apiserver/services/organization.py @@ -10,7 +10,7 @@ from apiserver.bll.project import ProjectBLL from apiserver.database.model import User, AttributedDocument, EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task +from apiserver.database.model.task.task import Task, TaskType from apiserver.service_repo import endpoint, APICall from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response @@ -59,6 +59,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): "models": Model, "pipelines": Project, "datasets": Project, + "reports": Task, } ret = {} for field, entity_cls in entity_classes.items(): @@ -66,6 +67,10 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): if data is None: continue + if field == "reports": + data["type"] = TaskType.report + data["include_subprojects"] = True + if request.active_users: if entity_cls is Project: requested_ids = data.get("id") @@ -75,7 +80,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): company=company, users=request.active_users, project_ids=requested_ids, - allow_public=True, + allow_public=request.allow_public, ) if not ids: ret[field] = 0 @@ -85,11 +90,18 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): data["user"] = request.active_users query = Q() - if entity_cls in (Project, Task) and not request.search_hidden: + if ( + entity_cls in (Project, Task) + and field != "reports" + and not request.search_hidden + ): query &= Q(system_tags__ne=EntityVisibility.hidden.value) ret[field] = entity_cls.get_count( - company=company, query_dict=data, query=query, allow_public=True, + company=company, + query_dict=data, + query=query, + allow_public=request.allow_public, ) call.result.data = ret diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 5e7bbd1..1bb4118 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -100,7 +100,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool): def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): data = call.data conform_tag_fields(call, data) - allow_public = not request.non_public + allow_public = ( + data["allow_public"] + if "allow_public" in data + else not data["non_public"] + if "non_public" in data + else request.allow_public + ) + requested_ids = data.get("id") if isinstance(requested_ids, str): requested_ids = [requested_ids] diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py index a78a634..3ba30d1 100644 --- a/apiserver/services/queues.py +++ b/apiserver/services/queues.py @@ -142,7 +142,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest): @endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest) def delete(call: APICall, company_id, req_model: DeleteRequest): queue_bll.delete( - company_id=company_id, queue_id=req_model.queue, force=req_model.force + company_id=company_id, + user_id=call.identity.user, + queue_id=req_model.queue, + force=req_model.force, ) call.result.data = {"deleted": 1} diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py index 5cea332..220acc9 100644 --- a/apiserver/services/reports.py +++ b/apiserver/services/reports.py @@ -51,9 +51,7 @@ update_fields = { } -def _assert_report( - company_id, task_id, only_fields=None, requires_write_access=True -): +def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True): if only_fields and "type" not in only_fields: only_fields += ("type",) @@ -72,9 +70,7 @@ def _assert_report( @endpoint("reports.update", response_data_model=UpdateResponse) def update_report(call: APICall, company_id: str, request: UpdateReportRequest): task = _assert_report( - task_id=request.task, - company_id=company_id, - only_fields=("status",), + task_id=request.task, company_id=company_id, only_fields=("status",), ) if task.status != TaskStatus.created: raise errors.bad_request.InvalidTaskStatus( @@ -196,7 +192,7 @@ def _get_task_metrics_from_request( return task_metrics -@endpoint("reports.get_task_data", required_fields=[]) +@endpoint("reports.get_task_data") def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): call_data = escape_execution_parameters(call) process_include_subprojects(call_data) @@ -212,16 +208,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): unprepare_from_saved(call, tasks) res = {"tasks": tasks, **ret_params} if not ( - request.debug_images - or request.plots - or request.scalar_metrics_iter_histogram + request.debug_images or request.plots or request.scalar_metrics_iter_histogram ): return res task_ids = [task["id"] for task in tasks] - company, tasks_or_models = _get_task_or_model_index_company( - company_id, task_ids - ) + company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids) if request.debug_images: result = event_bll.debug_images_iterator.get_task_events( company_id=company, @@ -264,9 +256,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest): ) task = _assert_report( - company_id=company_id, - task_id=request.task, - only_fields=("project",), + company_id=company_id, task_id=request.task, only_fields=("project",), ) user_id = call.identity.user project_name = request.project_name @@ -297,12 +287,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest): "reports.publish", response_data_model=UpdateResponse, ) def publish(call: APICall, company_id, request: PublishReportRequest): - task = _assert_report( - company_id=company_id, task_id=request.task - ) + task = _assert_report(company_id=company_id, task_id=request.task) updates = ChangeStatusRequest( task=task, - company=company_id, new_status=TaskStatus.published, force=True, status_reason="", @@ -315,9 +302,7 @@ def publish(call: APICall, company_id, request: PublishReportRequest): @endpoint("reports.archive") def archive(call: APICall, company_id, request: ArchiveReportRequest): - task = _assert_report( - company_id=company_id, task_id=request.task - ) + task = _assert_report(company_id=company_id, task_id=request.task) archived = task.update( status_message=request.message, status_reason="", @@ -331,9 +316,7 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest): @endpoint("reports.unarchive") def unarchive(call: APICall, company_id, request: ArchiveReportRequest): - task = _assert_report( - company_id=company_id, task_id=request.task - ) + task = _assert_report(company_id=company_id, task_id=request.task) unarchived = task.update( status_message=request.message, status_reason="", @@ -359,9 +342,7 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest): @endpoint("reports.delete") def delete(call: APICall, company_id, request: DeleteReportRequest): task = _assert_report( - company_id=company_id, - task_id=request.task, - only_fields=("project",), + company_id=company_id, task_id=request.task, only_fields=("project",), ) if ( task.status != TaskStatus.created diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 3992163..e39fb41 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -64,6 +64,7 @@ from apiserver.apimodels.tasks import ( ResetBatchItem, CompletedRequest, CompletedResponse, + GetAllReq, ) from apiserver.bll.event import EventBLL from apiserver.bll.model import ModelBLL @@ -136,7 +137,7 @@ project_bll = ProjectBLL() def set_task_status_from_call( - request: UpdateRequest, company_id, new_status=None, **set_fields + request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields ) -> dict: fields_resolver = SetFieldsResolver(set_fields) task = TaskBLL.get_task_with_access( @@ -171,6 +172,7 @@ def set_task_status_from_call( status_reason=status_reason, status_message=status_message, force=force, + user_id=user_id, ).execute(**fields_resolver.get_fields(task)) @@ -214,8 +216,8 @@ def _hidden_query(data: dict) -> Q: return Q(system_tags__ne=EntityVisibility.hidden.value) -@endpoint("tasks.get_all_ex", required_fields=[]) -def get_all_ex(call: APICall, company_id, _): +@endpoint("tasks.get_all_ex") +def get_all_ex(call: APICall, company_id, request: GetAllReq): conform_tag_fields(call, call.data) call_data = escape_execution_parameters(call) @@ -226,7 +228,7 @@ def get_all_ex(call: APICall, company_id, _): company=company_id, query_dict=call_data, query=_hidden_query(call_data), - allow_public=True, + allow_public=request.allow_public, ret_params=ret_params, ) unprepare_from_saved(call, tasks) @@ -291,6 +293,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest): **stop_task( task_id=req_model.task, company_id=company_id, + user_id=call.identity.user, user_name=call.identity.user_name, status_reason=req_model.status_reason, force=req_model.force, @@ -308,6 +311,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest): func=partial( stop_task, company_id=company_id, + user_id=call.identity.user, user_name=call.identity.user_name, status_reason=request.status_reason, force=request.force, @@ -329,7 +333,8 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest): call.result.data_model = UpdateResponse( **set_task_status_from_call( req_model, - company_id, + company_id=company_id, + user_id=call.identity.user, new_status=TaskStatus.stopped, completed=datetime.utcnow(), ) @@ -345,7 +350,8 @@ def started(call: APICall, company_id, req_model: UpdateRequest): res = StartedResponse( **set_task_status_from_call( req_model, - company_id, + company_id=company_id, + user_id=call.identity.user, new_status=TaskStatus.in_progress, min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value ) @@ -359,7 +365,12 @@ def started(call: APICall, company_id, req_model: UpdateRequest): ) def failed(call: APICall, company_id, req_model: UpdateRequest): call.result.data_model = UpdateResponse( - **set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed) + **set_task_status_from_call( + req_model, + company_id=company_id, + user_id=call.identity.user, + new_status=TaskStatus.failed, + ) ) @@ -368,7 +379,11 @@ def failed(call: APICall, company_id, req_model: UpdateRequest): ) def close(call: APICall, company_id, req_model: UpdateRequest): call.result.data_model = UpdateResponse( - **set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed) + **set_task_status_from_call( + req_model, + company_id=company_id, + user_id=call.identity.user, + new_status=TaskStatus.closed) ) @@ -580,7 +595,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest): company_id=company_id, id=task_id, partial_update_dict=partial_update_dict, - injected_update=dict(last_change=datetime.utcnow()), + injected_update=dict( + last_change=datetime.utcnow(), last_changed_by=call.identity.user, + ), ) if updated_count: new_project = updated_fields.get("project", task.project) @@ -613,7 +630,11 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques raise errors.bad_request.MissingTaskFields( "Task has no script field", task=task.id ) - res = update_task(task, update_cmds=dict(script__requirements=requirements)) + res = update_task( + task, + user_id=call.identity.user, + update_cmds=dict(script__requirements=requirements), + ) call.result.data_model = UpdateResponse(updated=res) if res: call.result.data_model.fields = {"script.requirements": requirements} @@ -648,7 +669,9 @@ def update_batch(call: APICall, company_id, _): partial_update_dict = Task.get_safe_update_dict(fields) if not partial_update_dict: continue - partial_update_dict.update(last_change=now) + partial_update_dict.update( + last_change=now, last_changed_by=call.identity.user, + ) update_op = UpdateOne( {"_id": id, "company": company_id}, {"$set": partial_update_dict} ) @@ -725,7 +748,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): } if fixed_fields: now = datetime.utcnow() - last_change = dict(last_change=now) + last_change = dict(last_change=now, last_changed_by=call.identity.user) if not set(fields).issubset(Task.user_set_allowed()): last_change.update(last_update=now) fields.update(**last_change) @@ -762,6 +785,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest call.result.data = { "updated": HyperParams.edit_params( company_id, + user_id=call.identity.user, task_id=request.task, hyperparams=request.hyperparams, replace_hyperparams=request.replace_hyperparams, @@ -775,6 +799,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq call.result.data = { "deleted": HyperParams.delete_params( company_id, + user_id=call.identity.user, task_id=request.task, hyperparams=request.hyperparams, force=request.force, @@ -819,6 +844,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ call.result.data = { "updated": HyperParams.edit_configuration( company_id, + user_id=call.identity.user, task_id=request.task, configuration=request.configuration, replace_configuration=request.replace_configuration, @@ -834,6 +860,7 @@ def delete_configuration( call.result.data = { "deleted": HyperParams.delete_configuration( company_id, + user_id=call.identity.user, task_id=request.task, configuration=request.configuration, force=request.force, @@ -850,6 +877,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest): queued, res = enqueue_task( task_id=request.task, company_id=company_id, + user_id=call.identity.user, queue_id=request.queue, status_message=request.status_message, status_reason=request.status_reason, @@ -874,6 +902,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): func=partial( enqueue_task, company_id=company_id, + user_id=call.identity.user, queue_id=request.queue, status_message=request.status_message, status_reason=request.status_reason, @@ -908,6 +937,7 @@ def dequeue(call: APICall, company_id, request: UpdateRequest): dequeued, res = dequeue_task( task_id=request.task, company_id=company_id, + user_id=call.identity.user, status_message=request.status_message, status_reason=request.status_reason, ) @@ -924,6 +954,7 @@ def dequeue_many(call: APICall, company_id, request: TaskBatchRequest): func=partial( dequeue_task, company_id=company_id, + user_id=call.identity.user, status_message=request.status_message, status_reason=request.status_reason, ), @@ -1019,6 +1050,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): for task in tasks: archived += archive_task( company_id=company_id, + user_id=call.identity.user, task=task, status_message=request.status_message, status_reason=request.status_reason, @@ -1037,6 +1069,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): func=partial( archive_task, company_id=company_id, + user_id=call.identity.user, status_message=request.status_message, status_reason=request.status_reason, ), @@ -1058,6 +1091,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): func=partial( unarchive_task, company_id=company_id, + user_id=call.identity.user, status_message=request.status_message, status_reason=request.status_reason, ), @@ -1136,6 +1170,7 @@ def publish(call: APICall, company_id, request: PublishRequest): updates = publish_task( task_id=request.task, company_id=company_id, + user_id=call.identity.user, force=request.force, publish_model_func=ModelBLL.publish_model if request.publish_model else None, status_reason=request.status_reason, @@ -1154,6 +1189,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest): func=partial( publish_task, company_id=company_id, + user_id=call.identity.user, force=request.force, publish_model_func=ModelBLL.publish_model if request.publish_model @@ -1180,7 +1216,8 @@ def completed(call: APICall, company_id, request: CompletedRequest): res = CompletedResponse( **set_task_status_from_call( request, - company_id, + company_id=company_id, + user_id=call.identity.user, new_status=TaskStatus.completed, completed=datetime.utcnow(), ) @@ -1190,6 +1227,7 @@ def completed(call: APICall, company_id, request: CompletedRequest): publish_res = publish_task( task_id=request.task, company_id=company_id, + user_id=call.identity.user, force=request.force, publish_model_func=ModelBLL.publish_model, status_reason=request.status_reason, @@ -1221,6 +1259,7 @@ def add_or_update_artifacts( call.result.data = { "updated": Artifacts.add_or_update_artifacts( company_id=company_id, + user_id=call.identity.user, task_id=request.task, artifacts=request.artifacts, force=True, @@ -1237,6 +1276,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest) call.result.data = { "deleted": Artifacts.delete_artifacts( company_id=company_id, + user_id=call.identity.user, task_id=request.task, artifact_ids=request.artifacts, force=True, @@ -1304,7 +1344,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ @endpoint("tasks.delete_models", min_version="2.13") -def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest): +def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest): task = get_task_for_update(company_id=company_id, task_id=request.task, force=True) delete_names = { @@ -1317,5 +1357,5 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest): if names } - updated = task.update(last_change=datetime.utcnow(), **commands,) + updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,) return {"updated": updated}