diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index ce4a617..29a5376 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -121,6 +121,7 @@ class CloneRequest(TaskRequest): new_task_project = StringField() new_task_hyperparams = DictField() new_task_configuration = DictField() + new_task_container = DictField() new_task_input_models = ListField([TaskInputModel]) execution_overrides = DictField() validate_references = BoolField(default=False) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index c31cab6..3f859cc 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -37,9 +37,11 @@ from .sub_projects import ( _get_sub_projects, _ids_with_children, _ids_with_parents, + _get_project_depth, ) log = config.logger(__file__) +max_depth = config.get("services.projects.sub_projects.max_depth", 10) class ProjectBLL: @@ -60,6 +62,15 @@ class ProjectBLL: source = Project.get(company, source_id) destination = Project.get(company, destination_id) + children = _get_sub_projects( + [source.id], _only=("id", "name", "parent", "path") + )[source.id] + cls.validate_projects_depth( + projects=children, + old_parent_depth=len(source.path) + 1, + new_parent_depth=len(destination.path) + 1, + ) + moved_entities = 0 for entity_type in (Task, Model): moved_entities += entity_type.objects( @@ -70,7 +81,11 @@ class ProjectBLL: moved_sub_projects = 0 for child in Project.objects(company=company, parent=source_id): - _reposition_project_with_children(project=child, parent=destination) + _reposition_project_with_children( + project=child, + children=[c for c in children if c.parent == child.id], + parent=destination, + ) moved_sub_projects += 1 affected = {source.id, *(source.path or [])} @@ -82,6 +97,15 @@ class ProjectBLL: return moved_entities, moved_sub_projects, affected + @staticmethod + def validate_projects_depth( + projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int + ): + for current in projects: + current_depth = len(current.path) + 1 + if current_depth - old_parent_depth + new_parent_depth > max_depth: + raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth) + @classmethod def move_project( cls, company: str, user: str, project_id: str, new_location: str @@ -100,6 +124,16 @@ class ProjectBLL: if old_parent_id else None ) + + children = _get_sub_projects([project.id], _only=("id", "name", "path"))[ + project.id + ] + cls.validate_projects_depth( + projects=[project, *children], + old_parent_depth=len(project.path), + new_parent_depth=_get_project_depth(new_location), + ) + new_parent = _ensure_project(company=company, user=user, name=new_location) new_parent_id = new_parent.id if new_parent else None if old_parent_id == new_parent_id: @@ -107,7 +141,9 @@ class ProjectBLL: location=new_parent.name if new_parent else "" ) - moved = _reposition_project_with_children(project, parent=new_parent) + moved = _reposition_project_with_children( + project, children=children, parent=new_parent + ) now = datetime.utcnow() affected = set() @@ -138,7 +174,12 @@ class ProjectBLL: if new_name: old_name = project.name project.name = new_name - _update_subproject_names(project=project, old_name=old_name) + children = _get_sub_projects( + [project.id], _only=("id", "name", "path") + )[project.id] + _update_subproject_names( + project=project, children=children, old_name=old_name + ) return updated @@ -157,6 +198,9 @@ class ProjectBLL: Create a new project. Returns project ID """ + if _get_project_depth(name) > max_depth: + raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth) + name, location = _validate_project_name(name) now = datetime.utcnow() project = Project( diff --git a/apiserver/bll/project/sub_projects.py b/apiserver/bll/project/sub_projects.py index 06d33ac..51f7736 100644 --- a/apiserver/bll/project/sub_projects.py +++ b/apiserver/bll/project/sub_projects.py @@ -4,26 +4,24 @@ from typing import Tuple, Optional, Sequence, Mapping from apiserver import database from apiserver.apierrors import errors -from apiserver.config_repo import config from apiserver.database.model.project import Project name_separator = "/" -max_depth = config.get("services.projects.sub_projects.max_depth", 10) + + +def _get_project_depth(project_name: str) -> int: + return len(list(filter(None, project_name.split(name_separator)))) def _validate_project_name(project_name: str) -> Tuple[str, str]: """ Remove redundant '/' characters. Ensure that the project name is not empty - and path to it is not larger then max_depth parameter. Return the cleaned up project name and location """ name_parts = list(filter(None, project_name.split(name_separator))) if not name_parts: raise errors.bad_request.InvalidProjectName(name=project_name) - if len(name_parts) > max_depth: - raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth) - return name_separator.join(name_parts), name_separator.join(name_parts[:-1]) @@ -135,6 +133,7 @@ def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]: def _update_subproject_names( project: Project, + children: Sequence[Project], old_name: str, update_path: bool = False, old_path: Sequence[str] = None, @@ -143,9 +142,8 @@ def _update_subproject_names( Update sub project names when the base project name changes Optionally update the paths """ - child_projects = _get_sub_projects(project_ids=[project.id], _only=("id", "name")) updated = 0 - for child in child_projects[project.id]: + for child in children: child_suffix = name_separator.join( child.name.split(name_separator)[len(old_name.split(name_separator)) :] ) @@ -157,7 +155,9 @@ def _update_subproject_names( return updated -def _reposition_project_with_children(project: Project, parent: Project) -> int: +def _reposition_project_with_children( + project: Project, children: Sequence[Project], parent: Project +) -> int: new_location = parent.name if parent else None old_name = project.name old_path = project.path @@ -167,6 +167,10 @@ def _reposition_project_with_children(project: Project, parent: Project) -> int: _save_under_parent(project, parent=parent) moved = 1 + _update_subproject_names( - project=project, old_name=old_name, update_path=True, old_path=old_path + project=project, + children=children, + old_name=old_name, + update_path=True, + old_path=old_path, ) return moved diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index adf70c0..5cc457d 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -177,6 +177,7 @@ class TaskBLL: system_tags: Optional[Sequence[str]] = None, hyperparams: Optional[dict] = None, configuration: Optional[dict] = None, + container: Optional[dict] = None, execution_overrides: Optional[dict] = None, input_models: Optional[Sequence[TaskInputModel]] = None, validate_references: bool = False, @@ -204,6 +205,11 @@ class TaskBLL: if not input_models and execution_model: input_models = [ModelItem(model=execution_model, name="input")] + docker_cmd = execution_overrides.pop("docker_cmd", None) + if not container and docker_cmd: + image, _, arguments = docker_cmd.partition(" ") + container = {"image": image, "arguments": arguments} + artifacts_prepare_for_save({"execution": execution_overrides}) params_dict["execution"] = {} @@ -272,6 +278,7 @@ class TaskBLL: if task.output else None, models=Models(input=input_models or task.models.input), + container=container or task.container, execution=execution_dict, configuration=params_dict.get("configuration") or task.configuration, hyperparams=params_dict.get("hyperparams") or task.hyperparams, diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index b86c198..6e3bb02 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -241,6 +241,19 @@ _definitions { } } } + container { + type: object + properties { + image { + type: string + description: "Docker image" + } + arguments { + type: string + description: "Docker command arguments" + } + } + } execution { type: object properties { @@ -488,6 +501,10 @@ _definitions { description: "Task execution params" "$ref": "#/definitions/execution" } + container { + description: "Docker container parameters" + "$ref": "#/definitions/container" + } models { description: "Task models" "$ref": "#/definitions/task_models" @@ -879,6 +896,11 @@ clone { type: array items {"$ref": "#/definitions/task_model_item"} } + new_task_container { + description: "The docker container properties for the new task. If not provided then taken from the original task" + type: object + additionalProperties { type: string } + } } } } @@ -1052,6 +1074,10 @@ create { description: "Task models" "$ref": "#/definitions/task_models" } + container { + description: "Docker container parameters" + "$ref": "#/definitions/container" + } } } } @@ -1136,6 +1162,10 @@ validate { description: "Task models" "$ref": "#/definitions/task_models" } + container { + description: "Docker container parameters" + "$ref": "#/definitions/container" + } } } } @@ -1321,6 +1351,10 @@ edit { description: "Task models" "$ref": "#/definitions/task_models" } + container { + description: "Docker container parameters" + "$ref": "#/definitions/container" + } } } } diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 86d6878..ce5ec8f 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -336,6 +336,7 @@ create_fields = { "project": None, "input": None, "models": None, + "container": None, "output_dest": None, "execution": None, "hyperparams": None, @@ -478,6 +479,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest): system_tags=request.new_task_system_tags, hyperparams=request.new_task_hyperparams, configuration=request.new_task_configuration, + container=request.new_task_container, execution_overrides=request.execution_overrides, input_models=request.new_task_input_models, validate_references=request.validate_references,