mirror of
https://github.com/clearml/clearml-server
synced 2025-06-25 17:35:47 +00:00
Merge pipeline parameters with original task hyperparameters
This commit is contained in:
parent
fa41e14625
commit
58b748ddf3
@ -39,6 +39,7 @@ from apiserver.database.utils import (
|
|||||||
from apiserver.es_factory import es_factory
|
from apiserver.es_factory import es_factory
|
||||||
from apiserver.redis_manager import redman
|
from apiserver.redis_manager import redman
|
||||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||||
|
from apiserver.utilities.dicts import nested_set
|
||||||
from .artifacts import artifacts_prepare_for_save
|
from .artifacts import artifacts_prepare_for_save
|
||||||
from .param_utils import params_prepare_for_save
|
from .param_utils import params_prepare_for_save
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -163,18 +164,35 @@ class TaskBLL:
|
|||||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||||
validate_references: bool = False,
|
validate_references: bool = False,
|
||||||
new_project_name: str = None,
|
new_project_name: str = None,
|
||||||
|
hyperparams_overrides: Optional[dict] = None,
|
||||||
|
configuration_overrides: Optional[dict] = None,
|
||||||
) -> Tuple[Task, dict]:
|
) -> Tuple[Task, dict]:
|
||||||
validate_tags(tags, system_tags)
|
validate_tags(tags, system_tags)
|
||||||
params_dict = {
|
task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||||
field: value
|
|
||||||
for field, value in (
|
|
||||||
("hyperparams", hyperparams),
|
|
||||||
("configuration", configuration),
|
|
||||||
)
|
|
||||||
if value is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
params_dict = {}
|
||||||
|
if hyperparams:
|
||||||
|
params_dict["hyperparams"] = hyperparams
|
||||||
|
elif hyperparams_overrides:
|
||||||
|
updated_hyperparams = {
|
||||||
|
sec: {k: value for k, value in sec_data.items()}
|
||||||
|
for sec, sec_data in (task.hyperparams or {}).items()
|
||||||
|
}
|
||||||
|
for section, section_data in hyperparams_overrides.items():
|
||||||
|
for key, value in section_data.items():
|
||||||
|
nested_set(updated_hyperparams, (section, key), value)
|
||||||
|
params_dict["hyperparams"] = updated_hyperparams
|
||||||
|
|
||||||
|
if configuration:
|
||||||
|
params_dict["configuration"] = configuration
|
||||||
|
elif configuration_overrides:
|
||||||
|
updated_configuration = {
|
||||||
|
k: value
|
||||||
|
for k, value in (task.configuration or {}).items()
|
||||||
|
}
|
||||||
|
for key, value in configuration_overrides.items():
|
||||||
|
updated_configuration[key] = value
|
||||||
|
params_dict["configuration"] = updated_configuration
|
||||||
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
if input_models:
|
if input_models:
|
||||||
|
@ -109,7 +109,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
|||||||
company_id=company_id,
|
company_id=company_id,
|
||||||
user_id=call.identity.user,
|
user_id=call.identity.user,
|
||||||
task_id=request.task,
|
task_id=request.task,
|
||||||
hyperparams=hyperparams,
|
hyperparams_overrides=hyperparams,
|
||||||
)
|
)
|
||||||
|
|
||||||
_update_task_name(task)
|
_update_task_name(task)
|
||||||
|
@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService
|
|||||||
|
|
||||||
|
|
||||||
class TestPipelines(TestService):
|
class TestPipelines(TestService):
|
||||||
|
task_hyperparams = {
|
||||||
|
"properties":
|
||||||
|
{
|
||||||
|
"version": {
|
||||||
|
"section": "properties",
|
||||||
|
"name": "version",
|
||||||
|
"type": "str",
|
||||||
|
"value": "3.2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def test_controller_operations(self):
|
def test_controller_operations(self):
|
||||||
task_name = "pipelines test"
|
task_name = "pipelines test"
|
||||||
project, task = self._temp_project_and_task(name=task_name)
|
project, task = self._temp_project_and_task(name=task_name)
|
||||||
@ -82,14 +94,17 @@ class TestPipelines(TestService):
|
|||||||
self.assertEqual(pipeline.status, "queued")
|
self.assertEqual(pipeline.status, "queued")
|
||||||
self.assertEqual(pipeline.project.id, project)
|
self.assertEqual(pipeline.project.id, project)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pipeline.hyperparams.Args,
|
pipeline.hyperparams,
|
||||||
{
|
{
|
||||||
a["name"]: {
|
"Args": {
|
||||||
"section": "Args",
|
a["name"]: {
|
||||||
"name": a["name"],
|
"section": "Args",
|
||||||
"value": a["value"],
|
"name": a["name"],
|
||||||
}
|
"value": a["value"],
|
||||||
for a in args
|
}
|
||||||
|
for a in args
|
||||||
|
},
|
||||||
|
**self.task_hyperparams,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -124,6 +139,7 @@ class TestPipelines(TestService):
|
|||||||
type="controller",
|
type="controller",
|
||||||
project=project,
|
project=project,
|
||||||
system_tags=["pipeline"],
|
system_tags=["pipeline"],
|
||||||
|
hyperparams=self.task_hyperparams,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user