mirror of
https://github.com/clearml/clearml-server
synced 2025-04-10 15:55:27 +00:00
Fix tasks.clone
This commit is contained in:
parent
df334d083e
commit
29c792d459
@ -111,8 +111,8 @@ class CloneRequest(TaskRequest):
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
|
@ -29,6 +29,7 @@ from apiserver.database.model.task.task import (
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
@ -179,7 +180,8 @@ class TaskBLL:
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Task:
|
||||
) -> Tuple[Task, dict]:
|
||||
validate_tags(tags, system_tags)
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
@ -215,18 +217,20 @@ class TaskBLL:
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
}
|
||||
|
||||
new_project_data = None
|
||||
if not project and new_project_name:
|
||||
# Use a project with the provided name, or create a new project
|
||||
project = project_bll.find_or_create(
|
||||
project = ProjectBLL.find_or_create(
|
||||
project_name=new_project_name,
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
description="Auto-generated while cloning",
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "clone task"):
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
@ -270,7 +274,7 @@ class TaskBLL:
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
||||
return new_task
|
||||
return new_task, new_project_data
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
@ -280,6 +284,11 @@ class TaskBLL:
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
"""
|
||||
Validate task properties according to the flag
|
||||
Task project is always checked for being writable
|
||||
in order to disable the modification of public projects
|
||||
"""
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
@ -289,12 +298,10 @@ class TaskBLL:
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
if task.project:
|
||||
project = Project.get_for_writing(company=task.company, id=task.project)
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
@ -781,6 +781,18 @@ clone {
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
new_project {
|
||||
description: "In case the new_project_name was specified returns the target project details"
|
||||
type: object
|
||||
properties {
|
||||
id: "The ID of the target project"
|
||||
name: "The name of the target project"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
|
@ -87,7 +87,6 @@ from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
validate_tags,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
@ -415,12 +414,9 @@ def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
|
||||
)
|
||||
@endpoint("tasks.clone", request_data_model=CloneRequest)
|
||||
def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
validate_tags(request.new_task_tags, request.new_task_system_tags)
|
||||
task = task_bll.clone_task(
|
||||
task, new_project = task_bll.clone_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
@ -430,13 +426,16 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
project=request.new_task_project,
|
||||
tags=request.new_task_tags,
|
||||
system_tags=request.new_task_system_tags,
|
||||
hyperparams=request.new_hyperparams,
|
||||
configuration=request.new_configuration,
|
||||
hyperparams=request.new_task_hyperparams,
|
||||
configuration=request.new_task_configuration,
|
||||
execution_overrides=request.execution_overrides,
|
||||
validate_references=request.validate_references,
|
||||
new_project_name=request.new_project_name,
|
||||
)
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
call.result.data = {
|
||||
"id": task.id,
|
||||
**({"new_project": new_project} if new_project else {}),
|
||||
}
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
@ -468,9 +467,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
return UpdateResponse(updated=0)
|
||||
|
||||
updated_count, updated_fields = Task.safe_update(
|
||||
company_id=company_id,
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
company_id=company_id, id=task_id, partial_update_dict=partial_update_dict,
|
||||
)
|
||||
if updated_count:
|
||||
new_project = updated_fields.get("project", task.project)
|
||||
@ -879,7 +876,13 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
|
||||
).execute(
|
||||
started=None,
|
||||
completed=None,
|
||||
published=None,
|
||||
active_duration=None,
|
||||
**updates,
|
||||
)
|
||||
)
|
||||
|
||||
# do not return artifacts since they are not serializable
|
||||
@ -906,10 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
for task in tasks:
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
request.status_message,
|
||||
request.status_reason,
|
||||
task, company_id, request.status_message, request.status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
|
@ -18,10 +18,14 @@ class TestMoveUnderProject(TestService):
|
||||
|
||||
# task clone
|
||||
p2_name = "project_for_clone"
|
||||
task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id
|
||||
res = self.api.tasks.clone(task=task, new_project_name=p2_name)
|
||||
task2 = res.id
|
||||
project_data = res.new_project
|
||||
self.assertTrue(project_data.id)
|
||||
self.assertEqual(p2_name, project_data.name)
|
||||
tasks = self.api.tasks.get_all_ex(id=[task2]).tasks
|
||||
project2 = tasks[0].project.id
|
||||
self.assertTrue(project2)
|
||||
self.assertEqual(project_data.id, project2)
|
||||
projects = self.api.projects.get_all_ex(id=[project2]).projects
|
||||
self.assertEqual(p2_name, projects[0].name)
|
||||
self.api.projects.delete(project=project2, force=True)
|
||||
|
Loading…
Reference in New Issue
Block a user