from operator import itemgetter from typing import Sequence, List, Tuple from boltons import iterutils from apierrors.errors.bad_request import InvalidTaskStatus from tests.api_client import APIClient from tests.automated import TestService class TestTasksHyperparams(TestService): def setUp(self, **kwargs): super().setUp(version="2.9") def new_task(self, **kwargs) -> Tuple[str, str]: if "project" not in kwargs: kwargs["project"] = self.create_temp( "projects", name="Test hyperparams", description="test", delete_params=dict(force=True), ) self.update_missing( kwargs, type="testing", name="test hyperparams", input=dict(view=dict()), delete_params=dict(force=True), ) return self.create_temp("tasks", **kwargs), kwargs["project"] def test_hyperparams(self): legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"} new_params = [ dict(section="1/1", name="param1/1", type="type1", value="10"), dict(section="1/1", name="param2", type="type1", value="20"), dict(section="2", name="param2", type="type2", value="xxx"), ] new_params_dict = self._param_dict_from_list(new_params) task, project = self.new_task( execution={"parameters": legacy_params}, hyperparams=new_params_dict, ) # both params and hyper params are set correctly old_params = self._new_params_from_legacy(legacy_params) params_dict = new_params_dict.copy() params_dict["Args"] = {p["name"]: p for p in old_params} res = self.api.tasks.get_by_id(task=task).task self.assertEqual(params_dict, res.hyperparams) # returned as one list with params in the _legacy section res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(new_params + old_params, res.hyperparams) # replace section replace_params = [ dict(section="1/1", name="param1", type="type1", value="40"), dict(section="2", name="param5", type="type1", value="11"), ] self.api.tasks.edit_hyper_params( task=task, hyperparams=replace_params, replace_hyperparams="section" ) res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(replace_params + old_params, res.hyperparams) # replace all replace_params = [ dict(section="1/1", name="param1/1", type="type1", value="30"), dict(section="Args", name="legacy$1", value="123", type="legacy"), ] self.api.tasks.edit_hyper_params( task=task, hyperparams=replace_params, replace_hyperparams="all" ) res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(replace_params, res.hyperparams) # add and update self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params) res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(new_params + old_params, res.hyperparams) # delete new_to_delete = self._get_param_keys(new_params[1:]) old_to_delete = self._get_param_keys(old_params[:1]) self.api.tasks.delete_hyper_params( task=task, hyperparams=new_to_delete + old_to_delete ) res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams) # delete section self.api.tasks.delete_hyper_params( task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}] ) res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] self.assertEqual(old_params[1:], res.hyperparams) # project hyperparams res = self.api.projects.get_hyper_parameters(project=project) self.assertEqual( [ {k: v for k, v in p.items() if k in ("section", "name")} for p in old_params[1:] ], res.parameters, ) # clone task new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id try: res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0] self.assertEqual(new_params, res.hyperparams) finally: self.api.tasks.delete(task=new_task, force=True) # editing of started task self.api.tasks.started(task=task) with self.api.raises(InvalidTaskStatus): self.api.tasks.edit_hyper_params( task=task, hyperparams=[dict(section="test", name="x", value="123")] ) self.api.tasks.edit_hyper_params( task=task, hyperparams=[dict(section="properties", name="x", value="123")] ) self.api.tasks.delete_hyper_params( task=task, hyperparams=[dict(section="Properties")] ) @staticmethod def _get_param_keys(params: Sequence[dict]) -> List[dict]: return [{k: p[k] for k in ("name", "section")} for p in params] @staticmethod def _new_params_from_legacy(legacy: dict) -> List[dict]: return [ dict(section="Args", name=k, value=str(v), type="legacy") if not k.startswith("TF_DEFINE/") else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy") for k, v in legacy.items() ] @staticmethod def _param_dict_from_list(params: Sequence[dict]) -> dict: return { k: {v["name"]: v for v in values} for k, values in iterutils.bucketize( params, key=itemgetter("section") ).items() } @staticmethod def _config_dict_from_list(config: Sequence[dict]) -> dict: return {c["name"]: c for c in config} def test_configuration(self): legacy_config = {"design": "hello"} new_config = [ dict(name="param$1", type="type1", value="10"), dict(name="param/2", type="type1", value="20"), ] new_config_dict = self._config_dict_from_list(new_config) task, _ = self.new_task( execution={"model_desc": legacy_config}, configuration=new_config_dict ) # both params and hyper params are set correctly old_config = self._new_config_from_legacy(legacy_config) config_dict = new_config_dict.copy() config_dict["design"] = old_config[0] res = self.api.tasks.get_by_id(task=task).task self.assertEqual(config_dict, res.configuration) # returned as one list res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] self.assertEqual(old_config + new_config, res.configuration) # names res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0] self.assertEqual(task, res.task) self.assertEqual(["design", "param$1", "param/2"], res.names) # returned as one list with names filtering res = self.api.tasks.get_configurations( tasks=[task], names=[new_config[1]["name"]] ).configurations[0] self.assertEqual([new_config[1]], res.configuration) # replace all replace_configs = [ dict(name="design", value="123", type="legacy"), dict(name="param/2", type="type1", value="30"), ] self.api.tasks.edit_configuration( task=task, configuration=replace_configs, replace_configuration=True ) res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] self.assertEqual(replace_configs, res.configuration) # add and update self.api.tasks.edit_configuration( task=task, configuration=new_config + old_config ) res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] self.assertEqual(old_config + new_config, res.configuration) # delete new_to_delete = self._get_config_keys(new_config[1:]) res = self.api.tasks.delete_configuration( task=task, configuration=new_to_delete ) res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] self.assertEqual(old_config + new_config[:1], res.configuration) # clone task new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id try: res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0] self.assertEqual(new_config, res.configuration) finally: self.api.tasks.delete(task=new_task, force=True) @staticmethod def _get_config_keys(config: Sequence[dict]) -> List[dict]: return [c["name"] for c in config] @staticmethod def _new_config_from_legacy(legacy: dict) -> List[dict]: return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()] def test_hyperparams_projection(self): legacy_param = {"legacy.1": "val1"} new_params1 = [ dict(section="sec.tion1", name="param1", type="type1", value="10") ] new_params_dict1 = self._param_dict_from_list(new_params1) task1, project = self.new_task( execution={"parameters": legacy_param}, hyperparams=new_params_dict1, ) new_params2 = [ dict(section="sec.tion1", name="param1", type="type1", value="20") ] new_params_dict2 = self._param_dict_from_list(new_params2) task2, _ = self.new_task(hyperparams=new_params_dict2, project=project) old_params = self._new_params_from_legacy(legacy_param) params_dict = new_params_dict1.copy() params_dict["Args"] = {p["name"]: p for p in old_params} res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[ 0 ] self.assertEqual(params_dict, res.hyperparams) res = self.api.tasks.get_all_ex( project=[project], only_fields=["hyperparams.sec%2Etion1"], order_by=["-hyperparams.sec%2Etion1"], ).tasks[0] self.assertEqual(new_params_dict2, res.hyperparams) def test_old_api(self): legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"} legacy_config = {"design": "hello"} task_id, _ = self.new_task( execution={"parameters": legacy_params, "model_desc": legacy_config} ) config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config)) params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params)) old_api = APIClient(base_url="http://localhost:8008/v2.8") task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] self.assertEqual(legacy_params, task.execution.parameters) self.assertEqual(legacy_config, task.execution.model_desc) self.assertEqual(params, task.hyperparams) self.assertEqual(config, task.configuration) modified_params = {"legacy.2": "val2"} modified_config = {"design": "by"} old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config)) task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] self.assertEqual(modified_params, task.execution.parameters) self.assertEqual(modified_config, task.execution.model_desc)