Merge pipeline parameters with original task hyperparameters

This commit is contained in:
clearml 2024-12-05 19:11:36 +02:00
parent fa41e14625
commit 58b748ddf3
3 changed files with 51 additions and 17 deletions

View File

@ -39,6 +39,7 @@ from apiserver.database.utils import (
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.utilities.dicts import nested_set
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@ -163,18 +164,35 @@ class TaskBLL:
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
hyperparams_overrides: Optional[dict] = None,
configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
params_dict = {}
if hyperparams:
params_dict["hyperparams"] = hyperparams
elif hyperparams_overrides:
updated_hyperparams = {
sec: {k: value for k, value in sec_data.items()}
for sec, sec_data in (task.hyperparams or {}).items()
}
for section, section_data in hyperparams_overrides.items():
for key, value in section_data.items():
nested_set(updated_hyperparams, (section, key), value)
params_dict["hyperparams"] = updated_hyperparams
if configuration:
params_dict["configuration"] = configuration
elif configuration_overrides:
updated_configuration = {
k: value
for k, value in (task.configuration or {}).items()
}
for key, value in configuration_overrides.items():
updated_configuration[key] = value
params_dict["configuration"] = updated_configuration
now = datetime.utcnow()
if input_models:

View File

@ -109,7 +109,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=hyperparams,
hyperparams_overrides=hyperparams,
)
_update_task_name(task)

View File

@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
task_hyperparams = {
"properties":
{
"version": {
"section": "properties",
"name": "version",
"type": "str",
"value": "3.2"
}
}
}
def test_controller_operations(self):
task_name = "pipelines test"
project, task = self._temp_project_and_task(name=task_name)
@ -82,14 +94,17 @@ class TestPipelines(TestService):
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
pipeline.hyperparams,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
"Args": {
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
**self.task_hyperparams,
},
)
@ -124,6 +139,7 @@ class TestPipelines(TestService):
type="controller",
project=project,
system_tags=["pipeline"],
hyperparams=self.task_hyperparams,
),
)