mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 02:46:53 +00:00
Merge pipeline parameters with original task hyperparameters
This commit is contained in:
parent
fa41e14625
commit
58b748ddf3
@ -39,6 +39,7 @@ from apiserver.database.utils import (
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@ -163,18 +164,35 @@ class TaskBLL:
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
hyperparams_overrides: Optional[dict] = None,
|
||||
configuration_overrides: Optional[dict] = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
validate_tags(tags, system_tags)
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
params_dict = {}
|
||||
if hyperparams:
|
||||
params_dict["hyperparams"] = hyperparams
|
||||
elif hyperparams_overrides:
|
||||
updated_hyperparams = {
|
||||
sec: {k: value for k, value in sec_data.items()}
|
||||
for sec, sec_data in (task.hyperparams or {}).items()
|
||||
}
|
||||
for section, section_data in hyperparams_overrides.items():
|
||||
for key, value in section_data.items():
|
||||
nested_set(updated_hyperparams, (section, key), value)
|
||||
params_dict["hyperparams"] = updated_hyperparams
|
||||
|
||||
if configuration:
|
||||
params_dict["configuration"] = configuration
|
||||
elif configuration_overrides:
|
||||
updated_configuration = {
|
||||
k: value
|
||||
for k, value in (task.configuration or {}).items()
|
||||
}
|
||||
for key, value in configuration_overrides.items():
|
||||
updated_configuration[key] = value
|
||||
params_dict["configuration"] = updated_configuration
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
|
@ -109,7 +109,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
hyperparams=hyperparams,
|
||||
hyperparams_overrides=hyperparams,
|
||||
)
|
||||
|
||||
_update_task_name(task)
|
||||
|
@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestPipelines(TestService):
|
||||
task_hyperparams = {
|
||||
"properties":
|
||||
{
|
||||
"version": {
|
||||
"section": "properties",
|
||||
"name": "version",
|
||||
"type": "str",
|
||||
"value": "3.2"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def test_controller_operations(self):
|
||||
task_name = "pipelines test"
|
||||
project, task = self._temp_project_and_task(name=task_name)
|
||||
@ -82,14 +94,17 @@ class TestPipelines(TestService):
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
pipeline.hyperparams,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
"Args": {
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
**self.task_hyperparams,
|
||||
},
|
||||
)
|
||||
|
||||
@ -124,6 +139,7 @@ class TestPipelines(TestService):
|
||||
type="controller",
|
||||
project=project,
|
||||
system_tags=["pipeline"],
|
||||
hyperparams=self.task_hyperparams,
|
||||
),
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user