mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Fix tasks.clone
This commit is contained in:
parent
df334d083e
commit
29c792d459
@ -111,8 +111,8 @@ class CloneRequest(TaskRequest):
|
|||||||
new_task_system_tags = ListField([str])
|
new_task_system_tags = ListField([str])
|
||||||
new_task_parent = StringField()
|
new_task_parent = StringField()
|
||||||
new_task_project = StringField()
|
new_task_project = StringField()
|
||||||
new_hyperparams = DictField()
|
new_task_hyperparams = DictField()
|
||||||
new_configuration = DictField()
|
new_task_configuration = DictField()
|
||||||
execution_overrides = DictField()
|
execution_overrides = DictField()
|
||||||
validate_references = BoolField(default=False)
|
validate_references = BoolField(default=False)
|
||||||
new_project_name = StringField()
|
new_project_name = StringField()
|
||||||
|
@ -29,6 +29,7 @@ from apiserver.database.model.task.task import (
|
|||||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||||
from apiserver.es_factory import es_factory
|
from apiserver.es_factory import es_factory
|
||||||
from apiserver.service_repo import APICall
|
from apiserver.service_repo import APICall
|
||||||
|
from apiserver.services.utils import validate_tags
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||||
from .artifacts import artifacts_prepare_for_save
|
from .artifacts import artifacts_prepare_for_save
|
||||||
@ -179,7 +180,8 @@ class TaskBLL:
|
|||||||
execution_overrides: Optional[dict] = None,
|
execution_overrides: Optional[dict] = None,
|
||||||
validate_references: bool = False,
|
validate_references: bool = False,
|
||||||
new_project_name: str = None,
|
new_project_name: str = None,
|
||||||
) -> Task:
|
) -> Tuple[Task, dict]:
|
||||||
|
validate_tags(tags, system_tags)
|
||||||
params_dict = {
|
params_dict = {
|
||||||
field: value
|
field: value
|
||||||
for field, value in (
|
for field, value in (
|
||||||
@ -215,18 +217,20 @@ class TaskBLL:
|
|||||||
if a.get("mode") != ArtifactModes.output
|
if a.get("mode") != ArtifactModes.output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
new_project_data = None
|
||||||
if not project and new_project_name:
|
if not project and new_project_name:
|
||||||
# Use a project with the provided name, or create a new project
|
# Use a project with the provided name, or create a new project
|
||||||
project = project_bll.find_or_create(
|
project = ProjectBLL.find_or_create(
|
||||||
project_name=new_project_name,
|
project_name=new_project_name,
|
||||||
user=user_id,
|
user=user_id,
|
||||||
company=company_id,
|
company=company_id,
|
||||||
description="Auto-generated while cloning",
|
description="Auto-generated while cloning",
|
||||||
)
|
)
|
||||||
|
new_project_data = {"id": project, "name": new_project_name}
|
||||||
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
|
||||||
with translate_errors_context():
|
with TimingContext("mongo", "clone task"):
|
||||||
new_task = Task(
|
new_task = Task(
|
||||||
id=create_id(),
|
id=create_id(),
|
||||||
user=user_id,
|
user=user_id,
|
||||||
@ -270,7 +274,7 @@ class TaskBLL:
|
|||||||
system_tags=updated_system_tags,
|
system_tags=updated_system_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_task
|
return new_task, new_project_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(
|
def validate(
|
||||||
@ -280,6 +284,11 @@ class TaskBLL:
|
|||||||
validate_parent=True,
|
validate_parent=True,
|
||||||
validate_project=True,
|
validate_project=True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Validate task properties according to the flag
|
||||||
|
Task project is always checked for being writable
|
||||||
|
in order to disable the modification of public projects
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
validate_parent
|
validate_parent
|
||||||
and task.parent
|
and task.parent
|
||||||
@ -289,11 +298,9 @@ class TaskBLL:
|
|||||||
):
|
):
|
||||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||||
|
|
||||||
if (
|
if task.project:
|
||||||
validate_project
|
project = Project.get_for_writing(company=task.company, id=task.project)
|
||||||
and task.project
|
if validate_project and not project:
|
||||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
|
||||||
):
|
|
||||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||||
|
|
||||||
if validate_model:
|
if validate_model:
|
||||||
|
@ -781,6 +781,18 @@ clone {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
response {
|
||||||
|
properties {
|
||||||
|
new_project {
|
||||||
|
description: "In case the new_project_name was specified returns the target project details"
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
id: "The ID of the target project"
|
||||||
|
name: "The name of the target project"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
create {
|
create {
|
||||||
|
@ -87,7 +87,6 @@ from apiserver.service_repo import APICall, endpoint
|
|||||||
from apiserver.services.utils import (
|
from apiserver.services.utils import (
|
||||||
conform_tag_fields,
|
conform_tag_fields,
|
||||||
conform_output_tags,
|
conform_output_tags,
|
||||||
validate_tags,
|
|
||||||
)
|
)
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
@ -415,12 +414,9 @@ def create(call: APICall, company_id, req_model: CreateRequest):
|
|||||||
call.result.data_model = IdResponse(id=task.id)
|
call.result.data_model = IdResponse(id=task.id)
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint("tasks.clone", request_data_model=CloneRequest)
|
||||||
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
|
|
||||||
)
|
|
||||||
def clone_task(call: APICall, company_id, request: CloneRequest):
|
def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||||
validate_tags(request.new_task_tags, request.new_task_system_tags)
|
task, new_project = task_bll.clone_task(
|
||||||
task = task_bll.clone_task(
|
|
||||||
company_id=company_id,
|
company_id=company_id,
|
||||||
user_id=call.identity.user,
|
user_id=call.identity.user,
|
||||||
task_id=request.task,
|
task_id=request.task,
|
||||||
@ -430,13 +426,16 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
|||||||
project=request.new_task_project,
|
project=request.new_task_project,
|
||||||
tags=request.new_task_tags,
|
tags=request.new_task_tags,
|
||||||
system_tags=request.new_task_system_tags,
|
system_tags=request.new_task_system_tags,
|
||||||
hyperparams=request.new_hyperparams,
|
hyperparams=request.new_task_hyperparams,
|
||||||
configuration=request.new_configuration,
|
configuration=request.new_task_configuration,
|
||||||
execution_overrides=request.execution_overrides,
|
execution_overrides=request.execution_overrides,
|
||||||
validate_references=request.validate_references,
|
validate_references=request.validate_references,
|
||||||
new_project_name=request.new_project_name,
|
new_project_name=request.new_project_name,
|
||||||
)
|
)
|
||||||
call.result.data_model = IdResponse(id=task.id)
|
call.result.data = {
|
||||||
|
"id": task.id,
|
||||||
|
**({"new_project": new_project} if new_project else {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def prepare_update_fields(call: APICall, task, call_data):
|
def prepare_update_fields(call: APICall, task, call_data):
|
||||||
@ -468,9 +467,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
|||||||
return UpdateResponse(updated=0)
|
return UpdateResponse(updated=0)
|
||||||
|
|
||||||
updated_count, updated_fields = Task.safe_update(
|
updated_count, updated_fields = Task.safe_update(
|
||||||
company_id=company_id,
|
company_id=company_id, id=task_id, partial_update_dict=partial_update_dict,
|
||||||
id=task_id,
|
|
||||||
partial_update_dict=partial_update_dict,
|
|
||||||
)
|
)
|
||||||
if updated_count:
|
if updated_count:
|
||||||
new_project = updated_fields.get("project", task.project)
|
new_project = updated_fields.get("project", task.project)
|
||||||
@ -879,7 +876,13 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
|||||||
force=force,
|
force=force,
|
||||||
status_reason="reset",
|
status_reason="reset",
|
||||||
status_message="reset",
|
status_message="reset",
|
||||||
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
|
).execute(
|
||||||
|
started=None,
|
||||||
|
completed=None,
|
||||||
|
published=None,
|
||||||
|
active_duration=None,
|
||||||
|
**updates,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# do not return artifacts since they are not serializable
|
# do not return artifacts since they are not serializable
|
||||||
@ -906,10 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
|||||||
for task in tasks:
|
for task in tasks:
|
||||||
try:
|
try:
|
||||||
TaskBLL.dequeue_and_change_status(
|
TaskBLL.dequeue_and_change_status(
|
||||||
task,
|
task, company_id, request.status_message, request.status_reason,
|
||||||
company_id,
|
|
||||||
request.status_message,
|
|
||||||
request.status_reason,
|
|
||||||
)
|
)
|
||||||
except APIError:
|
except APIError:
|
||||||
# dequeue may fail if the task was not enqueued
|
# dequeue may fail if the task was not enqueued
|
||||||
|
@ -18,10 +18,14 @@ class TestMoveUnderProject(TestService):
|
|||||||
|
|
||||||
# task clone
|
# task clone
|
||||||
p2_name = "project_for_clone"
|
p2_name = "project_for_clone"
|
||||||
task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id
|
res = self.api.tasks.clone(task=task, new_project_name=p2_name)
|
||||||
|
task2 = res.id
|
||||||
|
project_data = res.new_project
|
||||||
|
self.assertTrue(project_data.id)
|
||||||
|
self.assertEqual(p2_name, project_data.name)
|
||||||
tasks = self.api.tasks.get_all_ex(id=[task2]).tasks
|
tasks = self.api.tasks.get_all_ex(id=[task2]).tasks
|
||||||
project2 = tasks[0].project.id
|
project2 = tasks[0].project.id
|
||||||
self.assertTrue(project2)
|
self.assertEqual(project_data.id, project2)
|
||||||
projects = self.api.projects.get_all_ex(id=[project2]).projects
|
projects = self.api.projects.get_all_ex(id=[project2]).projects
|
||||||
self.assertEqual(p2_name, projects[0].name)
|
self.assertEqual(p2_name, projects[0].name)
|
||||||
self.api.projects.delete(project=project2, force=True)
|
self.api.projects.delete(project=project2, force=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user