diff --git a/server/apimodels/tasks.py b/server/apimodels/tasks.py index a025aeb..604bcaa 100644 --- a/server/apimodels/tasks.py +++ b/server/apimodels/tasks.py @@ -100,6 +100,7 @@ class CloneRequest(TaskRequest): new_task_parent = StringField() new_task_project = StringField() execution_overrides = DictField() + validate_references = BoolField(default=False) class AddOrUpdateArtifactsRequest(TaskRequest): diff --git a/server/bll/project/__init__.py b/server/bll/project/__init__.py new file mode 100644 index 0000000..0b8ab93 --- /dev/null +++ b/server/bll/project/__init__.py @@ -0,0 +1 @@ +from .project_bll import ProjectBLL diff --git a/server/bll/project/project_bll.py b/server/bll/project/project_bll.py new file mode 100644 index 0000000..dcb577d --- /dev/null +++ b/server/bll/project/project_bll.py @@ -0,0 +1,33 @@ +from typing import Sequence, Optional + +from mongoengine import Q + +from config import config +from database.model.model import Model +from database.model.task.task import Task +from timing_context import TimingContext + +log = config.logger(__file__) + + +class ProjectBLL: + @classmethod + def get_active_users( + cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None + ) -> set: + """ + Get the set of user ids that created tasks/models in the given projects + If project_ids is empty then all projects are examined + If user_ids are passed then only subset of these users is returned + """ + with TimingContext("mongo", "active_users_in_projects"): + res = set() + query = Q(company=company) + if project_ids: + query &= Q(project__in=project_ids) + if user_ids: + query &= Q(user__in=user_ids) + for cls_ in (Task, Model): + res |= set(cls_.objects(query).distinct(field="user")) + + return res diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index f7896d4..9ac4fdc 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -164,9 +164,11 @@ class TaskBLL(object): tags: Optional[Sequence[str]] = 
None, system_tags: Optional[Sequence[str]] = None, execution_overrides: Optional[dict] = None, + validate_references: bool = False, ) -> Task: task = cls.get_by_id(company_id=company_id, task_id=task_id) execution_dict = task.execution.to_proper_dict() if task.execution else {} + execution_model_overriden = False if execution_overrides: parameters = execution_overrides.get("parameters") if parameters is not None: @@ -174,6 +176,8 @@ class TaskBLL(object): ParameterKeyEscaper.escape(k): v for k, v in parameters.items() } execution_dict = deep_merge(execution_dict, execution_overrides) + execution_model_overriden = execution_overrides.get("model") is not None + artifacts = execution_dict.get("artifacts") if artifacts: execution_dict["artifacts"] = [ @@ -201,26 +205,42 @@ class TaskBLL(object): else None, execution=execution_dict, ) - cls.validate(new_task) + cls.validate( + new_task, + validate_model=validate_references or execution_model_overriden, + validate_parent=validate_references or parent, + validate_project=validate_references or project, + ) new_task.save() return new_task @classmethod - def validate(cls, task: Task): - assert isinstance(task, Task) - - if task.parent and not Task.get( - company=task.company, id=task.parent, _only=("id",), include_public=True + def validate( + cls, + task: Task, + validate_model=True, + validate_parent=True, + validate_project=True, + ): + if ( + validate_parent + and task.parent + and not Task.get( + company=task.company, id=task.parent, _only=("id",), include_public=True + ) ): raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent) - if task.project and not Project.get_for_writing( - company=task.company, id=task.project + if ( + validate_project + and task.project + and not Project.get_for_writing(company=task.company, id=task.project) ): raise errors.bad_request.InvalidProjectId(id=task.project) - cls.validate_execution_model(task) + if validate_model: + cls.validate_execution_model(task) 
@staticmethod def get_unique_metric_variants(company_id, project_ids=None): diff --git a/server/schema/services/events.conf b/server/schema/services/events.conf index cbd28b5..ce38e21 100644 --- a/server/schema/services/events.conf +++ b/server/schema/services/events.conf @@ -530,59 +530,59 @@ } } } - "2.7" { - description: "Get 'log' events for this task" - request { - type: object - required: [ - task - ] - properties { - task { - type: string - description: "Task ID" - } - batch_size { - type: integer - description: "The amount of log events to return" - } - navigate_earlier { - type: boolean - description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True" - } - refresh { - type: boolean - description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)" - } - scroll_id { - type: string - description: "Scroll ID of previous call (used for getting more results)" - } - } - } - response { - type: object - properties { - events { - type: array - items { type: object } - description: "Log items list" - } - returned { - type: integer - description: "Number of log events returned" - } - total { - type: number - description: "Total number of log events available for this query" - } - scroll_id { - type: string - description: "Scroll ID for getting more results" - } - } - } - } +// "2.7" { +// description: "Get 'log' events for this task" +// request { +// type: object +// required: [ +// task +// ] +// properties { +// task { +// type: string +// description: "Task ID" +// } +// batch_size { +// type: integer +// description: "The amount of log events to return" +// } +// navigate_earlier { +// type: boolean +// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). 
Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True" +// } +// refresh { +// type: boolean +// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)" +// } +// scroll_id { +// type: string +// description: "Scroll ID of previous call (used for getting more results)" +// } +// } +// } +// response { +// type: object +// properties { +// events { +// type: array +// items { type: object } +// description: "Log items list" +// } +// returned { +// type: integer +// description: "Number of log events returned" +// } +// total { +// type: number +// description: "Total number of log events available for this query" +// } +// scroll_id { +// type: string +// description: "Scroll ID for getting more results" +// } +// } +// } +// } } get_task_events { "2.1" { diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index b4cffca..299e026 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -591,6 +591,10 @@ clone { description: "The execution params for the cloned task. The params not specified are taken from the original task" "$ref": "#/definitions/execution" } + validate_references { + description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false." 
+ type: boolean + } } } response { diff --git a/server/services/events.py b/server/services/events.py index c4e19fe..71916e0 100644 --- a/server/services/events.py +++ b/server/services/events.py @@ -94,26 +94,27 @@ def get_task_log_v1_7(call, company_id, req_model): ) -@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest) -def get_task_log(call, company_id, req_model: LogEventsRequest): - task_id = req_model.task - task_bll.assert_exists(company_id, task_id, allow_public=True) - - res = event_bll.log_events_iterator.get_task_events( - company_id=company_id, - task_id=task_id, - batch_size=req_model.batch_size, - navigate_earlier=req_model.navigate_earlier, - refresh=req_model.refresh, - state_id=req_model.scroll_id, - ) - - call.result.data = dict( - events=res.events, - returned=len(res.events), - total=res.total_events, - scroll_id=res.next_scroll_id, - ) +# uncomment this once the front end is ready +# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest) +# def get_task_log(call, company_id, req_model: LogEventsRequest): +# task_id = req_model.task +# task_bll.assert_exists(company_id, task_id, allow_public=True) +# +# res = event_bll.log_events_iterator.get_task_events( +# company_id=company_id, +# task_id=task_id, +# batch_size=req_model.batch_size, +# navigate_earlier=req_model.navigate_earlier, +# refresh=req_model.refresh, +# state_id=req_model.scroll_id, +# ) +# +# call.result.data = dict( +# events=res.events, +# returned=len(res.events), +# total=res.total_events, +# scroll_id=res.next_scroll_id, +# ) @endpoint("events.download_task_log", required_fields=["task"]) diff --git a/server/services/tasks.py b/server/services/tasks.py index f99af1e..4d1e09d 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -361,6 +361,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest): tags=request.new_task_tags, system_tags=request.new_task_system_tags, 
execution_overrides=request.execution_overrides, + validate_references=request.validate_references, ) call.result.data_model = IdResponse(id=task.id) diff --git a/server/tests/automated/test_task_events.py b/server/tests/automated/test_task_events.py index 633a7da..d88eb7f 100644 --- a/server/tests/automated/test_task_events.py +++ b/server/tests/automated/test_task_events.py @@ -186,6 +186,7 @@ class TestTaskEvents(TestService): self.assertEqual(len(res.events), 1) def test_task_logs(self): + # this test will fail until the new api is uncommented task = self._temp_task() timestamp = es_factory.get_timestamp_millis() events = [ diff --git a/server/tests/automated/test_tasks_edit.py b/server/tests/automated/test_tasks_edit.py index 078ba20..5e9cf18 100644 --- a/server/tests/automated/test_tasks_edit.py +++ b/server/tests/automated/test_tasks_edit.py @@ -74,13 +74,13 @@ class TestTasksEdit(TestService): new_name = "new test" new_tags = ["by"] execution_overrides = dict(framework="Caffe") - new_task_id = self.api.tasks.clone( + new_task_id = self._clone_task( task=task, new_task_name=new_name, new_task_tags=new_tags, execution_overrides=execution_overrides, new_task_parent=task, - ).id + ) new_task = self.api.tasks.get_by_id(task=new_task_id).task self.assertEqual(new_task.name, new_name) self.assertEqual(new_task.type, "testing") @@ -91,3 +91,32 @@ class TestTasksEdit(TestService): self.assertEqual(new_task.execution.parameters, execution["parameters"]) self.assertEqual(new_task.execution.framework, execution_overrides["framework"]) self.assertEqual(new_task.system_tags, []) + + def test_model_check_in_clone(self): + model = self.new_model() + task = self.new_task(execution=dict(model=model)) + + # task with deleted model still can be copied + self.api.models.delete(model=model, force=True) + self._clone_task(task=task, new_task_name="clone test") + + # unless check for refs is done + with self.api.raises(InvalidModelId): + self._clone_task( + task=task, 
new_task_name="clone test2", validate_references=True
+            )
+
+        # if the model is overridden then it is always checked
+        with self.api.raises(InvalidModelId):
+            self._clone_task(
+                task=task,
+                new_task_name="clone test3",
+                execution_overrides=dict(model="not existing"),
+            )
+
+    def _clone_task(self, task, **kwargs):
+        new_task = self.api.tasks.clone(task=task, **kwargs).id
+        self.defer(
+            self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
+        )
+        return new_task