From cd4ce30f7c0aa3c715993a5f36bd2b42c251353d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 10 Aug 2020 08:48:48 +0300 Subject: [PATCH] Add support for field exclusion in get_all endpoints Add support for ephemeral worker tags (valid while worker has not timed out) --- server/apimodels/projects.py | 6 - server/apimodels/tasks.py | 75 ++++- server/apimodels/workers.py | 2 + server/bll/workers/__init__.py | 11 +- server/database/model/base.py | 54 ++-- server/database/projection.py | 36 ++- server/schema/services/workers.conf | 15 + server/services/workers.py | 10 +- server/tests/automated/test_projection.py | 78 ++++- .../tests/automated/test_task_hyperparams.py | 281 ++++++++++++++++++ 10 files changed, 526 insertions(+), 42 deletions(-) create mode 100644 server/tests/automated/test_task_hyperparams.py diff --git a/server/apimodels/projects.py b/server/apimodels/projects.py index d1780e5..5201471 100644 --- a/server/apimodels/projects.py +++ b/server/apimodels/projects.py @@ -13,11 +13,5 @@ class GetHyperParamReq(ProjectReq): page_size = fields.IntField(default=500) -class GetHyperParamResp(models.Base): - parameters = fields.ListField(str) - remaining = fields.IntField() - total = fields.IntField() - - class ProjectTagsRequest(TagsRequest): projects = ListField(str) diff --git a/server/apimodels/tasks.py b/server/apimodels/tasks.py index d864552..aa16939 100644 --- a/server/apimodels/tasks.py +++ b/server/apimodels/tasks.py @@ -1,7 +1,9 @@ +from typing import Sequence + import six from jsonmodels import models from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField -from jsonmodels.validators import Enum +from jsonmodels.validators import Enum, Length from apimodels import DictField, ListField from apimodels.base import UpdateResponse @@ -103,6 +105,8 @@ class CloneRequest(TaskRequest): new_task_system_tags = ListField([str]) new_task_parent = StringField() new_task_project = StringField() + new_hyperparams = DictField() + new_configuration = DictField() execution_overrides = DictField() validate_references = BoolField(default=False) @@ -118,3 +122,72 @@ class AddOrUpdateArtifactsResponse(models.Base): class ResetRequest(UpdateRequest): clear_all = BoolField(default=False) + + +class MultiTaskRequest(models.Base): + tasks = ListField([str], validators=Length(minimum_value=1)) + + +class GetHyperParamsRequest(MultiTaskRequest): + pass + + +class HyperParamItem(models.Base): + section = StringField(required=True, validators=Length(minimum_value=1)) + name = StringField(required=True, validators=Length(minimum_value=1)) + value = StringField(required=True) + type = StringField() + description = StringField() + + +class ReplaceHyperparams(object): + none = "none" + section = "section" + all = "all" + + +class EditHyperParamsRequest(TaskRequest): + hyperparams: Sequence[HyperParamItem] = ListField( + [HyperParamItem], validators=Length(minimum_value=1) + ) + replace_hyperparams = StringField( + validators=Enum(*get_options(ReplaceHyperparams)), + default=ReplaceHyperparams.none, + ) + + +class HyperParamKey(models.Base): + section = StringField(required=True, validators=Length(minimum_value=1)) + name = StringField(nullable=True) + + +class DeleteHyperParamsRequest(TaskRequest): + hyperparams: Sequence[HyperParamKey] = ListField( + [HyperParamKey], validators=Length(minimum_value=1) + ) + + +class GetConfigurationsRequest(MultiTaskRequest): + names = ListField([str]) + + +class GetConfigurationNamesRequest(MultiTaskRequest): + pass + + +class Configuration(models.Base): + name = StringField(required=True, validators=Length(minimum_value=1)) + value = StringField(required=True) + type = StringField() + description = StringField() + + +class EditConfigurationRequest(TaskRequest): + configuration: Sequence[Configuration] = ListField( + [Configuration], validators=Length(minimum_value=1) + ) + replace_configuration = BoolField(default=False) + + +class DeleteConfigurationRequest(TaskRequest): + configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1)) diff --git a/server/apimodels/workers.py b/server/apimodels/workers.py index 7fbb950..dead953 100644 --- a/server/apimodels/workers.py +++ b/server/apimodels/workers.py @@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 10 * 60 class WorkerRequest(Base): worker = StringField(required=True) + tags = ListField(str) class RegisterRequest(WorkerRequest): @@ -74,6 +75,7 @@ class WorkerEntry(Base, JsonSerializableMixin): register_timeout = IntField(required=True) last_activity_time = DateTimeField(required=True) last_report_time = DateTimeField() + tags = ListField(str) class CurrentTaskEntry(IdNameEntry): diff --git a/server/bll/workers/__init__.py b/server/bll/workers/__init__.py index cf04985..fdf2416 100644 --- a/server/bll/workers/__init__.py +++ b/server/bll/workers/__init__.py @@ -50,6 +50,7 @@ class WorkerBLL: ip: str = "", queues: Sequence[str] = None, timeout: int = 0, + tags: Sequence[str] = None, ) -> WorkerEntry: """ Register a worker @@ -59,6 +60,7 @@ class WorkerBLL: :param ip: the real ip of the worker :param queues: queues reported as being monitored by the worker :param timeout: registration expiration timeout in seconds + :param tags: a list of tags for this worker :raise bad_request.InvalidUserId: in case the calling user or company does not exist :return: worker entry instance """ @@ -92,6 +94,7 @@ class WorkerBLL: register_time=now, register_timeout=timeout, last_activity_time=now, + tags=tags, ) self.redis.setex(key, timedelta(seconds=timeout), entry.to_json()) @@ -114,12 +117,15 @@ class WorkerBLL: raise bad_request.WorkerNotRegistered(worker=worker) def status_report( - self, company_id: str, user_id: str, ip: str, report: StatusReportRequest + self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None, ) -> None: """ Write worker status report :param company_id: worker's company ID :param user_id: user_id ID under which this worker is running + :param ip: worker IP + :param report: the report itself + :param tags: tags for this worker :raise bad_request.InvalidTaskId: the reported task was not found :return: worker entry instance """ @@ -130,6 +136,9 @@ class WorkerBLL: now = datetime.utcnow() entry.last_activity_time = now + if tags is not None: + entry.tags = tags + if report.machine_stats: self._log_stats_to_es( company_id=company_id, diff --git a/server/database/model/base.py b/server/database/model/base.py index 9075bf5..3a3885c 100644 --- a/server/database/model/base.py +++ b/server/database/model/base.py @@ -1,9 +1,9 @@ import re from collections import namedtuple from functools import reduce -from typing import Collection, Sequence, Union, Optional, Type +from typing import Collection, Sequence, Union, Optional, Type, Tuple -from boltons.iterutils import first, bucketize +from boltons.iterutils import first, bucketize, partition from dateutil.parser import parse as parse_datetime from mongoengine import Q, Document, ListField, StringField from pymongo.command_cursor import CommandCursor @@ -348,6 +348,17 @@ class GetMixin(PropsMixin): return [] return parameters.get(cls._projection_key) or parameters.get("only_fields", []) + @classmethod + def split_projection( + cls, projection: Sequence[str] + ) -> Tuple[Collection[str], Collection[str]]: + """Return include and exclude lists based on passed projection and class definition""" + include, exclude = partition( + projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix, + ) + exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude} + return include, set(cls.get_exclude_fields()).union(exclude).difference(include) + @classmethod def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]: parameters.pop("only_fields", None) @@ -502,7 +513,7 @@ class GetMixin(PropsMixin): @classmethod def _get_many_no_company( cls: Union["GetMixin", Document], - query, + query: Q, parameters=None, override_projection=None, ): @@ -525,7 +536,9 @@ class GetMixin(PropsMixin): search_text = parameters.get(cls._search_text_key) order_by = cls.validate_order_by(parameters=parameters, search_text=search_text) page, page_size = cls.validate_paging(parameters=parameters) - only = cls.get_projection(parameters, override_projection) + include, exclude = cls.split_projection( + cls.get_projection(parameters, override_projection) + ) qs = cls.objects(query) if search_text: @@ -533,13 +546,14 @@ class GetMixin(PropsMixin): if order_by: # add ordering qs = qs.order_by(*order_by) - if only: + + if include: # add projection - qs = qs.only(*only) - else: - exclude = set(cls.get_exclude_fields()).difference(only) - if exclude: - qs = qs.exclude(*exclude) + qs = qs.only(*include) + + if exclude: + qs = qs.exclude(*exclude) + if page is not None and page_size: # add paging qs = qs.skip(page * page_size).limit(page_size) @@ -575,7 +589,9 @@ class GetMixin(PropsMixin): search_text = parameters.get(cls._search_text_key) order_by = cls.validate_order_by(parameters=parameters, search_text=search_text) page, page_size = cls.validate_paging(parameters=parameters) - only = cls.get_projection(parameters, override_projection) + include, exclude = cls.split_projection( + cls.get_projection(parameters, override_projection) + ) query_sets = [cls.objects(query)] if order_by: @@ -612,16 +628,15 @@ class GetMixin(PropsMixin): if search_text: query_sets = [qs.search_text(search_text) for qs in query_sets] - if only: + if include: # add projection - query_sets = [qs.only(*only) for qs in query_sets] - else: - exclude = set(cls.get_exclude_fields()) - if exclude: - query_sets = [qs.exclude(*exclude) for qs in query_sets] + query_sets = [qs.only(*include) for qs in query_sets] + + if exclude: + query_sets = [qs.exclude(*exclude) for qs in query_sets] if page is None or not page_size: - return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs] + return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs] # add paging ret = [] @@ -632,7 +647,8 @@ class GetMixin(PropsMixin): start -= qs_size continue ret.extend( - obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size) + obj.to_proper_dict(only=include) + for obj in qs.skip(start).limit(page_size) ) if len(ret) >= page_size: break diff --git a/server/database/projection.py b/server/database/projection.py index 9fd4336..a5d77f3 100644 --- a/server/database/projection.py +++ b/server/database/projection.py @@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP): ) dst[path_part] = [ - copy_path(path_parts[depth + 1:], s, d) + copy_path(path_parts[depth + 1 :], s, d) for s, d in zip(src_part, dst[path_part]) ] @@ -96,6 +96,7 @@ class _ProxyManager: class ProjectionHelper(object): pool = ThreadPoolExecutor() + exclusion_prefix = "-" @property def doc_projection(self): @@ -128,20 +129,28 @@ class ProjectionHelper(object): [] ) # Projection information for reference fields (used in join queries) for field in projection: + field_ = field.lstrip(self.exclusion_prefix) for ref_field, ref_field_cls in doc_cls.get_reference_fields().items(): - if not field.startswith(ref_field): + if not field_.startswith(ref_field): # Doesn't start with a reference field continue - if field == ref_field: + if field_ == ref_field: # Field is exactly a reference field. In this case we won't perform any inner projection (for that, # use '.*') continue - subfield = field[len(ref_field):] + subfield = field_[len(ref_field) :] if not subfield.startswith(SEP): # Starts with something that looks like a reference field, but isn't continue - ref_projection_info.append((ref_field, ref_field_cls, subfield[1:])) + ref_projection_info.append( + ( + ref_field, + ref_field_cls, + ("" if field_[0] == field[0] else self.exclusion_prefix) + + subfield[1:], + ) + ) break else: # Not a reference field, just add to the top-level projection @@ -149,7 +158,7 @@ class ProjectionHelper(object): orig_field = field if field.endswith(".*"): field = field[:-2] - if not field: + if not field.lstrip(self.exclusion_prefix): raise errors.bad_request.InvalidFields( field=orig_field, object=doc_cls.__name__ ) @@ -199,7 +208,7 @@ class ProjectionHelper(object): # Make sure this doesn't contain any reference field we'll join anyway # (i.e. in case only_fields=[project, project.name]) doc_projection = normalize_cls_projection( - doc_cls, doc_projection.difference(ref_projection).union({"id"}) + doc_cls, doc_projection.difference(ref_projection) ) # Make sure that in case one or more field is a subfield of another field, we only use the the top-level field. @@ -218,7 +227,10 @@ class ProjectionHelper(object): # Make sure we didn't get any invalid projection fields for this class invalid_fields = [ - f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields() + f + for f in doc_projection + if f.partition(SEP)[0].lstrip(self.exclusion_prefix) + not in doc_cls.get_fields() ] if invalid_fields: raise errors.bad_request.InvalidFields( @@ -234,6 +246,13 @@ class ProjectionHelper(object): doc_projection.add(field) doc_projection = list(doc_projection) + # If there are include fields (not only exclude) then add an id field + if ( + not all(p.startswith(self.exclusion_prefix) for p in doc_projection) + and "id" not in doc_projection + ): + doc_projection.append("id") + self._doc_projection = doc_projection self._ref_projection = ref_projection @@ -314,6 +333,7 @@ class ProjectionHelper(object): ] if items: + def do_projection(item): ref_field_name, data, ids = item diff --git a/server/schema/services/workers.conf b/server/schema/services/workers.conf index 81ef288..8b10b41 100644 --- a/server/schema/services/workers.conf +++ b/server/schema/services/workers.conf @@ -148,6 +148,11 @@ type: array items { "$ref": "#/definitions/queue_entry" } } + tags { + description: "User tags for the worker" + type: array + items: { type: string } + } } } @@ -305,6 +310,11 @@ type: array items { type: string } } + tags { + description: "User tags for the worker" + type: array + items: { type: string } + } } } response { @@ -367,6 +377,11 @@ description: "The machine statistics." "$ref": "#/definitions/machine_stats" } + tags { + description: "New user tags for the worker" + type: array + items: { type: string } + } } } response { diff --git a/server/services/workers.py b/server/services/workers.py index 0630350..6ef9f82 100644 --- a/server/services/workers.py +++ b/server/services/workers.py @@ -46,10 +46,10 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest): @endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest) -def register(call: APICall, company_id, req_model: RegisterRequest): - worker = req_model.worker - timeout = req_model.timeout - queues = req_model.queues +def register(call: APICall, company_id, request: RegisterRequest): + worker = request.worker + timeout = request.timeout + queues = request.queues if not timeout or timeout <= 0: raise bad_request.WorkerRegistrationFailed( @@ -63,6 +63,7 @@ def register(call: APICall, company_id, req_model: RegisterRequest): ip=call.real_ip, queues=queues, timeout=timeout, + tags=request.tags, ) @@ -78,6 +79,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest): user_id=call.identity.user, ip=call.real_ip, report=request, + tags=request.tags, ) diff --git a/server/tests/automated/test_projection.py b/server/tests/automated/test_projection.py index e80faf7..32142fa 100644 --- a/server/tests/automated/test_projection.py +++ b/server/tests/automated/test_projection.py @@ -6,14 +6,86 @@ log = config.logger(__file__) class TestProjection(TestService): + def setUp(self, **kwargs): + super().setUp(version="2.6") + + def _temp_task(self, **kwargs): + self.update_missing( + kwargs, + type="testing", + name="test projection", + input=dict(view=dict()), + delete_params=dict(force=True), + ) + return self.create_temp("tasks", **kwargs) + + def _temp_project(self): + return self.create_temp( + "projects", + name="Test projection", + description="test", + delete_params=dict(force=True), + ) + def test_overlapping_fields(self): message = "task started" - task_id = self.create_temp( - "tasks", name="test", type="testing", input=dict(view=dict()) - ) + task_id = self._temp_task() self.api.tasks.started(task=task_id, status_message=message) task = self.api.tasks.get_all_ex( id=[task_id], only_fields=["status", "status_message"] ).tasks[0] assert task["status"] == TaskStatus.in_progress assert task["status_message"] == message + + def test_task_projection(self): + project = self._temp_project() + task1 = self._temp_task(project=project) + task2 = self._temp_task(project=project) + self.api.tasks.started(task=task2, status_message="Started") + + res = self.api.tasks.get_all_ex( + project=[project], + only_fields=[ + "system_tags", + "company", + "type", + "name", + "tags", + "status", + "project.name", + "user.name", + "started", + "last_update", + "last_iteration", + "comment", + ], + order_by=["-started"], + page=0, + page_size=15, + system_tags=["-archived"], + type=[ + "__$not", + "annotation_manual", + "__$not", + "annotation", + "__$not", + "dataset_import", + ], + ).tasks + self.assertEqual([task2, task1], [t.id for t in res]) + self.assertEqual("Test projection", res[0].project.name) + + def test_exclude_projection(self): + task_id = self._temp_task() + + res = self.api.tasks.get_all_ex( + id=[task_id] + ).tasks[0] + self.assertEqual("test projection", res.name) + + task = self.api.tasks.get_all_ex( + id=[task_id], + only_fields=["-name"] + ).tasks[0] + self.assertFalse("name" in task) + self.assertEqual("testing", res.type) diff --git a/server/tests/automated/test_task_hyperparams.py b/server/tests/automated/test_task_hyperparams.py new file mode 100644 index 0000000..528cdd5 --- /dev/null +++ b/server/tests/automated/test_task_hyperparams.py @@ -0,0 +1,281 @@ +from operator import itemgetter +from typing import Sequence, List, Tuple + +from boltons import iterutils + +from apierrors.errors.bad_request import InvalidTaskStatus +from tests.api_client import APIClient +from tests.automated import TestService + + +class TestTasksHyperparams(TestService): + def setUp(self, **kwargs): + super().setUp(version="2.9") + + def new_task(self, **kwargs) -> Tuple[str, str]: + if "project" not in kwargs: + kwargs["project"] = self.create_temp( + "projects", + name="Test hyperparams", + description="test", + delete_params=dict(force=True), + ) + self.update_missing( + kwargs, + type="testing", + name="test hyperparams", + input=dict(view=dict()), + delete_params=dict(force=True), + ) + return self.create_temp("tasks", **kwargs), kwargs["project"] + + def test_hyperparams(self): + legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"} + new_params = [ + dict(section="1/1", name="param1/1", type="type1", value="10"), + dict(section="1/1", name="param2", type="type1", value="20"), + dict(section="2", name="param2", type="type2", value="xxx"), + ] + new_params_dict = self._param_dict_from_list(new_params) + task, project = self.new_task( + execution={"parameters": legacy_params}, hyperparams=new_params_dict, + ) + # both params and hyper params are set correctly + old_params = self._new_params_from_legacy(legacy_params) + params_dict = new_params_dict.copy() + params_dict["Args"] = {p["name"]: p for p in old_params} + res = self.api.tasks.get_by_id(task=task).task + self.assertEqual(params_dict, res.hyperparams) + + # returned as one list with params in the _legacy section + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(new_params + old_params, res.hyperparams) + + # replace section + replace_params = [ + dict(section="1/1", name="param1", type="type1", value="40"), + dict(section="2", name="param5", type="type1", value="11"), + ] + self.api.tasks.edit_hyper_params( + task=task, hyperparams=replace_params, replace_hyperparams="section" + ) + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(replace_params + old_params, res.hyperparams) + + # replace all + replace_params = [ + dict(section="1/1", name="param1/1", type="type1", value="30"), + dict(section="Args", name="legacy$1", value="123", type="legacy"), + ] + self.api.tasks.edit_hyper_params( + task=task, hyperparams=replace_params, replace_hyperparams="all" + ) + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(replace_params, res.hyperparams) + + # add and update + self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params) + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(new_params + old_params, res.hyperparams) + + # delete + new_to_delete = self._get_param_keys(new_params[1:]) + old_to_delete = self._get_param_keys(old_params[:1]) + self.api.tasks.delete_hyper_params( + task=task, hyperparams=new_to_delete + old_to_delete + ) + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams) + + # delete section + self.api.tasks.delete_hyper_params( + task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}] + ) + res = self.api.tasks.get_hyper_params(tasks=[task]).params[0] + self.assertEqual(old_params[1:], res.hyperparams) + + # project hyperparams + res = self.api.projects.get_hyper_parameters(project=project) + self.assertEqual( + [ + {k: v for k, v in p.items() if k in ("section", "name")} + for p in old_params[1:] + ], + res.parameters, + ) + + # clone task + new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id + try: + res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0] + self.assertEqual(new_params, res.hyperparams) + finally: + self.api.tasks.delete(task=new_task, force=True) + + # editing of started task + self.api.tasks.started(task=task) + with self.api.raises(InvalidTaskStatus): + self.api.tasks.edit_hyper_params( + task=task, hyperparams=[dict(section="test", name="x", value="123")] + ) + self.api.tasks.edit_hyper_params( + task=task, hyperparams=[dict(section="properties", name="x", value="123")] + ) + self.api.tasks.delete_hyper_params( + task=task, hyperparams=[dict(section="Properties")] + ) + + @staticmethod + def _get_param_keys(params: Sequence[dict]) -> List[dict]: + return [{k: p[k] for k in ("name", "section")} for p in params] + + @staticmethod + def _new_params_from_legacy(legacy: dict) -> List[dict]: + return [ + dict(section="Args", name=k, value=str(v), type="legacy") + if not k.startswith("TF_DEFINE/") + else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy") + for k, v in legacy.items() + ] + + @staticmethod + def _param_dict_from_list(params: Sequence[dict]) -> dict: + return { + k: {v["name"]: v for v in values} + for k, values in iterutils.bucketize( + params, key=itemgetter("section") + ).items() + } + + @staticmethod + def _config_dict_from_list(config: Sequence[dict]) -> dict: + return {c["name"]: c for c in config} + + def test_configuration(self): + legacy_config = {"design": "hello"} + new_config = [ + dict(name="param$1", type="type1", value="10"), + dict(name="param/2", type="type1", value="20"), + ] + new_config_dict = self._config_dict_from_list(new_config) + task, _ = self.new_task( + execution={"model_desc": legacy_config}, configuration=new_config_dict + ) + + # both params and hyper params are set correctly + old_config = self._new_config_from_legacy(legacy_config) + config_dict = new_config_dict.copy() + config_dict["design"] = old_config[0] + res = self.api.tasks.get_by_id(task=task).task + self.assertEqual(config_dict, res.configuration) + + # returned as one list + res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] + self.assertEqual(old_config + new_config, res.configuration) + + # names + res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0] + self.assertEqual(task, res.task) + self.assertEqual(["design", "param$1", "param/2"], res.names) + + # returned as one list with names filtering + res = self.api.tasks.get_configurations( + tasks=[task], names=[new_config[1]["name"]] + ).configurations[0] + self.assertEqual([new_config[1]], res.configuration) + + # replace all + replace_configs = [ + dict(name="design", value="123", type="legacy"), + dict(name="param/2", type="type1", value="30"), + ] + self.api.tasks.edit_configuration( + task=task, configuration=replace_configs, replace_configuration=True + ) + res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] + self.assertEqual(replace_configs, res.configuration) + + # add and update + self.api.tasks.edit_configuration( + task=task, configuration=new_config + old_config + ) + res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] + self.assertEqual(old_config + new_config, res.configuration) + + # delete + new_to_delete = self._get_config_keys(new_config[1:]) + res = self.api.tasks.delete_configuration( + task=task, configuration=new_to_delete + ) + res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] + self.assertEqual(old_config + new_config[:1], res.configuration) + + # clone task + new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id + try: + res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0] + self.assertEqual(new_config, res.configuration) + finally: + self.api.tasks.delete(task=new_task, force=True) + + @staticmethod + def _get_config_keys(config: Sequence[dict]) -> List[dict]: + return [c["name"] for c in config] + + @staticmethod + def _new_config_from_legacy(legacy: dict) -> List[dict]: + return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()] + + def test_hyperparams_projection(self): + legacy_param = {"legacy.1": "val1"} + new_params1 = [ + dict(section="sec.tion1", name="param1", type="type1", value="10") + ] + new_params_dict1 = self._param_dict_from_list(new_params1) + task1, project = self.new_task( + execution={"parameters": legacy_param}, hyperparams=new_params_dict1, + ) + + new_params2 = [ + dict(section="sec.tion1", name="param1", type="type1", value="20") + ] + new_params_dict2 = self._param_dict_from_list(new_params2) + task2, _ = self.new_task(hyperparams=new_params_dict2, project=project) + + old_params = self._new_params_from_legacy(legacy_param) + params_dict = new_params_dict1.copy() + params_dict["Args"] = {p["name"]: p for p in old_params} + res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[ + 0 + ] + self.assertEqual(params_dict, res.hyperparams) + + res = self.api.tasks.get_all_ex( + project=[project], + only_fields=["hyperparams.sec%2Etion1"], + order_by=["-hyperparams.sec%2Etion1"], + ).tasks[0] + self.assertEqual(new_params_dict2, res.hyperparams) + + def test_old_api(self): + legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"} + legacy_config = {"design": "hello"} + task_id, _ = self.new_task( + execution={"parameters": legacy_params, "model_desc": legacy_config} + ) + config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config)) + params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params)) + + old_api = APIClient(base_url="http://localhost:8008/v2.8") + task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] + self.assertEqual(legacy_params, task.execution.parameters) + self.assertEqual(legacy_config, task.execution.model_desc) + self.assertEqual(params, task.hyperparams) + self.assertEqual(config, task.configuration) + + modified_params = {"legacy.2": "val2"} + modified_config = {"design": "by"} + old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config)) + task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] + self.assertEqual(modified_params, task.execution.parameters) + self.assertEqual(modified_config, task.execution.model_desc)