From 58b748ddf30d8dce976795d96d8bd2db1b1287d0 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Thu, 5 Dec 2024 19:11:36 +0200 Subject: [PATCH] Merge pipeline parameters with original task hyperparameters --- apiserver/bll/task/task_bll.py | 36 +++++++++++++++------ apiserver/services/pipelines.py | 2 +- apiserver/tests/automated/test_pipelines.py | 30 +++++++++++++---- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 3883511..ebfae5c 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -39,6 +39,7 @@ from apiserver.database.utils import ( from apiserver.es_factory import es_factory from apiserver.redis_manager import redman from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict +from apiserver.utilities.dicts import nested_set from .artifacts import artifacts_prepare_for_save from .param_utils import params_prepare_for_save from .utils import ( @@ -163,18 +164,35 @@ class TaskBLL: input_models: Optional[Sequence[TaskInputModel]] = None, validate_references: bool = False, new_project_name: str = None, + hyperparams_overrides: Optional[dict] = None, + configuration_overrides: Optional[dict] = None, ) -> Tuple[Task, dict]: validate_tags(tags, system_tags) - params_dict = { - field: value - for field, value in ( - ("hyperparams", hyperparams), - ("configuration", configuration), - ) - if value is not None - } + task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True) - task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True) + params_dict = {} + if hyperparams: + params_dict["hyperparams"] = hyperparams + elif hyperparams_overrides: + updated_hyperparams = { + sec: {k: value for k, value in sec_data.items()} + for sec, sec_data in (task.hyperparams or {}).items() + } + for section, section_data in hyperparams_overrides.items(): + for key, value in section_data.items(): + nested_set(updated_hyperparams, (section, key), value) + params_dict["hyperparams"] = updated_hyperparams + + if configuration: + params_dict["configuration"] = configuration + elif configuration_overrides: + updated_configuration = { + k: value + for k, value in (task.configuration or {}).items() + } + for key, value in configuration_overrides.items(): + updated_configuration[key] = value + params_dict["configuration"] = updated_configuration now = datetime.utcnow() if input_models: diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py index a735db2..1927559 100644 --- a/apiserver/services/pipelines.py +++ b/apiserver/services/pipelines.py @@ -109,7 +109,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest company_id=company_id, user_id=call.identity.user, task_id=request.task, - hyperparams=hyperparams, + hyperparams_overrides=hyperparams, ) _update_task_name(task) diff --git a/apiserver/tests/automated/test_pipelines.py b/apiserver/tests/automated/test_pipelines.py index 3c1e0f9..ec75144 100644 --- a/apiserver/tests/automated/test_pipelines.py +++ b/apiserver/tests/automated/test_pipelines.py @@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService class TestPipelines(TestService): + task_hyperparams = { + "properties": + { + "version": { + "section": "properties", + "name": "version", + "type": "str", + "value": "3.2" + } + } + } + def test_controller_operations(self): task_name = "pipelines test" project, task = self._temp_project_and_task(name=task_name) @@ -82,14 +94,17 @@ class TestPipelines(TestService): self.assertEqual(pipeline.status, "queued") self.assertEqual(pipeline.project.id, project) self.assertEqual( - pipeline.hyperparams.Args, + pipeline.hyperparams, { - a["name"]: { - "section": "Args", - "name": a["name"], - "value": a["value"], - } - for a in args + "Args": { + a["name"]: { + "section": "Args", + "name": a["name"], + "value": a["value"], + } + for a in args + }, + **self.task_hyperparams, }, ) @@ -124,6 +139,7 @@ class TestPipelines(TestService): type="controller", project=project, system_tags=["pipeline"], + hyperparams=self.task_hyperparams, ), )