Fix Task.container backwards-compatibility

Fix sub-projects
This commit is contained in:
allegroai 2021-05-03 17:49:48 +03:00
parent dadb996d22
commit 2e7f418ee2
6 changed files with 105 additions and 13 deletions

View File

@ -121,6 +121,7 @@ class CloneRequest(TaskRequest):
new_task_project = StringField() new_task_project = StringField()
new_task_hyperparams = DictField() new_task_hyperparams = DictField()
new_task_configuration = DictField() new_task_configuration = DictField()
new_task_container = DictField()
new_task_input_models = ListField([TaskInputModel]) new_task_input_models = ListField([TaskInputModel])
execution_overrides = DictField() execution_overrides = DictField()
validate_references = BoolField(default=False) validate_references = BoolField(default=False)

View File

@ -37,9 +37,11 @@ from .sub_projects import (
_get_sub_projects, _get_sub_projects,
_ids_with_children, _ids_with_children,
_ids_with_parents, _ids_with_parents,
_get_project_depth,
) )
log = config.logger(__file__) log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
class ProjectBLL: class ProjectBLL:
@ -60,6 +62,15 @@ class ProjectBLL:
source = Project.get(company, source_id) source = Project.get(company, source_id)
destination = Project.get(company, destination_id) destination = Project.get(company, destination_id)
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0 moved_entities = 0
for entity_type in (Task, Model): for entity_type in (Task, Model):
moved_entities += entity_type.objects( moved_entities += entity_type.objects(
@ -70,7 +81,11 @@ class ProjectBLL:
moved_sub_projects = 0 moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id): for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(project=child, parent=destination) _reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1 moved_sub_projects += 1
affected = {source.id, *(source.path or [])} affected = {source.id, *(source.path or [])}
@ -82,6 +97,15 @@ class ProjectBLL:
return moved_entities, moved_sub_projects, affected return moved_entities, moved_sub_projects, affected
@staticmethod
def validate_projects_depth(
projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int
):
for current in projects:
current_depth = len(current.path) + 1
if current_depth - old_parent_depth + new_parent_depth > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
@classmethod @classmethod
def move_project( def move_project(
cls, company: str, user: str, project_id: str, new_location: str cls, company: str, user: str, project_id: str, new_location: str
@ -100,6 +124,16 @@ class ProjectBLL:
if old_parent_id if old_parent_id
else None else None
) )
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
)
new_parent = _ensure_project(company=company, user=user, name=new_location) new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id: if old_parent_id == new_parent_id:
@ -107,7 +141,9 @@ class ProjectBLL:
location=new_parent.name if new_parent else "" location=new_parent.name if new_parent else ""
) )
moved = _reposition_project_with_children(project, parent=new_parent) moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
now = datetime.utcnow() now = datetime.utcnow()
affected = set() affected = set()
@ -138,7 +174,12 @@ class ProjectBLL:
if new_name: if new_name:
old_name = project.name old_name = project.name
project.name = new_name project.name = new_name
_update_subproject_names(project=project, old_name=old_name) children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
return updated return updated
@ -157,6 +198,9 @@ class ProjectBLL:
Create a new project. Create a new project.
Returns project ID Returns project ID
""" """
if _get_project_depth(name) > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name) name, location = _validate_project_name(name)
now = datetime.utcnow() now = datetime.utcnow()
project = Project( project = Project(

View File

@ -4,26 +4,24 @@ from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database from apiserver import database
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
name_separator = "/" name_separator = "/"
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str) -> Tuple[str, str]: def _validate_project_name(project_name: str) -> Tuple[str, str]:
""" """
Remove redundant '/' characters. Ensure that the project name is not empty Remove redundant '/' characters. Ensure that the project name is not empty
and path to it is not larger then max_depth parameter.
Return the cleaned up project name and location Return the cleaned up project name and location
""" """
name_parts = list(filter(None, project_name.split(name_separator))) name_parts = list(filter(None, project_name.split(name_separator)))
if not name_parts: if not name_parts:
raise errors.bad_request.InvalidProjectName(name=project_name) raise errors.bad_request.InvalidProjectName(name=project_name)
if len(name_parts) > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
return name_separator.join(name_parts), name_separator.join(name_parts[:-1]) return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
@ -135,6 +133,7 @@ def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
def _update_subproject_names( def _update_subproject_names(
project: Project, project: Project,
children: Sequence[Project],
old_name: str, old_name: str,
update_path: bool = False, update_path: bool = False,
old_path: Sequence[str] = None, old_path: Sequence[str] = None,
@ -143,9 +142,8 @@ def _update_subproject_names(
Update sub project names when the base project name changes Update sub project names when the base project name changes
Optionally update the paths Optionally update the paths
""" """
child_projects = _get_sub_projects(project_ids=[project.id], _only=("id", "name"))
updated = 0 updated = 0
for child in child_projects[project.id]: for child in children:
child_suffix = name_separator.join( child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)) :] child.name.split(name_separator)[len(old_name.split(name_separator)) :]
) )
@ -157,7 +155,9 @@ def _update_subproject_names(
return updated return updated
def _reposition_project_with_children(project: Project, parent: Project) -> int: def _reposition_project_with_children(
project: Project, children: Sequence[Project], parent: Project
) -> int:
new_location = parent.name if parent else None new_location = parent.name if parent else None
old_name = project.name old_name = project.name
old_path = project.path old_path = project.path
@ -167,6 +167,10 @@ def _reposition_project_with_children(project: Project, parent: Project) -> int:
_save_under_parent(project, parent=parent) _save_under_parent(project, parent=parent)
moved = 1 + _update_subproject_names( moved = 1 + _update_subproject_names(
project=project, old_name=old_name, update_path=True, old_path=old_path project=project,
children=children,
old_name=old_name,
update_path=True,
old_path=old_path,
) )
return moved return moved

View File

@ -177,6 +177,7 @@ class TaskBLL:
system_tags: Optional[Sequence[str]] = None, system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None, hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None, configuration: Optional[dict] = None,
container: Optional[dict] = None,
execution_overrides: Optional[dict] = None, execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None, input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False, validate_references: bool = False,
@ -204,6 +205,11 @@ class TaskBLL:
if not input_models and execution_model: if not input_models and execution_model:
input_models = [ModelItem(model=execution_model, name="input")] input_models = [ModelItem(model=execution_model, name="input")]
docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
container = {"image": image, "arguments": arguments}
artifacts_prepare_for_save({"execution": execution_overrides}) artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {} params_dict["execution"] = {}
@ -272,6 +278,7 @@ class TaskBLL:
if task.output if task.output
else None, else None,
models=Models(input=input_models or task.models.input), models=Models(input=input_models or task.models.input),
container=container or task.container,
execution=execution_dict, execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration, configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams, hyperparams=params_dict.get("hyperparams") or task.hyperparams,

View File

@ -241,6 +241,19 @@ _definitions {
} }
} }
} }
container {
type: object
properties {
image {
type: string
description: "Docker image"
}
arguments {
type: string
description: "Docker command arguments"
}
}
}
execution { execution {
type: object type: object
properties { properties {
@ -488,6 +501,10 @@ _definitions {
description: "Task execution params" description: "Task execution params"
"$ref": "#/definitions/execution" "$ref": "#/definitions/execution"
} }
container {
description: "Docker container parameters"
"$ref": "#/definitions/container"
}
models { models {
description: "Task models" description: "Task models"
"$ref": "#/definitions/task_models" "$ref": "#/definitions/task_models"
@ -879,6 +896,11 @@ clone {
type: array type: array
items {"$ref": "#/definitions/task_model_item"} items {"$ref": "#/definitions/task_model_item"}
} }
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: string }
}
} }
} }
} }
@ -1052,6 +1074,10 @@ create {
description: "Task models" description: "Task models"
"$ref": "#/definitions/task_models" "$ref": "#/definitions/task_models"
} }
container {
description: "Docker container parameters"
"$ref": "#/definitions/container"
}
} }
} }
} }
@ -1136,6 +1162,10 @@ validate {
description: "Task models" description: "Task models"
"$ref": "#/definitions/task_models" "$ref": "#/definitions/task_models"
} }
container {
description: "Docker container parameters"
"$ref": "#/definitions/container"
}
} }
} }
} }
@ -1321,6 +1351,10 @@ edit {
description: "Task models" description: "Task models"
"$ref": "#/definitions/task_models" "$ref": "#/definitions/task_models"
} }
container {
description: "Docker container parameters"
"$ref": "#/definitions/container"
}
} }
} }
} }

View File

@ -336,6 +336,7 @@ create_fields = {
"project": None, "project": None,
"input": None, "input": None,
"models": None, "models": None,
"container": None,
"output_dest": None, "output_dest": None,
"execution": None, "execution": None,
"hyperparams": None, "hyperparams": None,
@ -478,6 +479,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
system_tags=request.new_task_system_tags, system_tags=request.new_task_system_tags,
hyperparams=request.new_task_hyperparams, hyperparams=request.new_task_hyperparams,
configuration=request.new_task_configuration, configuration=request.new_task_configuration,
container=request.new_task_container,
execution_overrides=request.execution_overrides, execution_overrides=request.execution_overrides,
input_models=request.new_task_input_models, input_models=request.new_task_input_models,
validate_references=request.validate_references, validate_references=request.validate_references,