Add support for field exclusion in get_all endpoints

Add support for ephemeral worker tags (valid while worker has not timed out)
allegroai 2020-08-10 08:48:48 +03:00
parent 8c7e230898
commit cd4ce30f7c
10 changed files with 526 additions and 42 deletions
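A hedged usage sketch of the new exclusion support (mirroring the test added below; self.api is the authenticated test client and task_id is a placeholder): prefixing a field with "-" in only_fields removes it from the returned documents instead of including it.

    # inside a TestService-style test case
    task = self.api.tasks.get_all_ex(
        id=[task_id],              # placeholder task id
        only_fields=["-name"],     # "-" prefix excludes the field from the result
    ).tasks[0]
    assert "name" not in task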

View File

@ -13,11 +13,5 @@ class GetHyperParamReq(ProjectReq):
page_size = fields.IntField(default=500)
class GetHyperParamResp(models.Base):
parameters = fields.ListField(str)
remaining = fields.IntField()
total = fields.IntField()
class ProjectTagsRequest(TagsRequest):
projects = ListField(str)

View File

@ -1,7 +1,9 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum
from jsonmodels.validators import Enum, Length
from apimodels import DictField, ListField
from apimodels.base import UpdateResponse
@ -103,6 +105,8 @@ class CloneRequest(TaskRequest):
new_task_system_tags = ListField([str])
new_task_parent = StringField()
new_task_project = StringField()
new_hyperparams = DictField()
new_configuration = DictField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
@ -118,3 +122,72 @@ class AddOrUpdateArtifactsResponse(models.Base):
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
class MultiTaskRequest(models.Base):
tasks = ListField([str], validators=Length(minimum_value=1))
class GetHyperParamsRequest(MultiTaskRequest):
pass
class HyperParamItem(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class ReplaceHyperparams(object):
none = "none"
section = "section"
all = "all"
class EditHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
replace_hyperparams = StringField(
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
class HyperParamKey(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
class GetConfigurationsRequest(MultiTaskRequest):
names = ListField([str])
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
class Configuration(models.Base):
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class EditConfigurationRequest(TaskRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
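A hedged call sketch for the new hyperparams request models, in the same client style as the tests added below (the section/name/value/type strings are illustrative placeholders):

    # replace_hyperparams accepts "none" (default), "section" or "all"
    self.api.tasks.edit_hyper_params(
        task=task_id,
        hyperparams=[dict(section="Args", name="lr", value="0.01", type="float")],
        replace_hyperparams="section",
    )
    # deleting by key; omitting "name" drops the whole section
    self.api.tasks.delete_hyper_params(task=task_id, hyperparams=[dict(section="Args")])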

View File

@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
class RegisterRequest(WorkerRequest):
@ -74,6 +75,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):

View File

@ -50,6 +50,7 @@ class WorkerBLL:
ip: str = "",
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@ -59,6 +60,7 @@ class WorkerBLL:
:param ip: the real ip of the worker
:param queues: queues reported as being monitored by the worker
:param timeout: registration expiration timeout in seconds
:param tags: a list of tags for this worker
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
:return: worker entry instance
"""
@ -92,6 +94,7 @@ class WorkerBLL:
register_time=now,
register_timeout=timeout,
last_activity_time=now,
tags=tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
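The tags are stored inside the same Redis entry that registration writes with SETEX, which is what makes them ephemeral: when the registration timeout expires the key is dropped and the tags go with it. A standalone illustration (key name, timeout and payload are placeholders, not the server's actual wiring):

    import json
    from datetime import timedelta
    import redis

    r = redis.Redis()
    entry = {"worker": "worker-1", "tags": ["gpu"]}  # placeholder worker entry
    r.setex("worker_<company>_worker-1", timedelta(seconds=600), json.dumps(entry))
    # once the 600s registration timeout elapses, the key expires and the tags disappear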
@ -114,12 +117,15 @@ class WorkerBLL:
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
:param company_id: worker's company ID
:param user_id: ID of the user under which this worker is running
:param ip: worker IP
:param report: the report itself
:param tags: tags for this worker
:raise bad_request.InvalidTaskId: the reported task was not found
:return: worker entry instance
"""
@ -130,6 +136,9 @@ class WorkerBLL:
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
if report.machine_stats:
self._log_stats_to_es(
company_id=company_id,

View File

@ -1,9 +1,9 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from boltons.iterutils import first, bucketize
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
@ -348,6 +348,17 @@ class GetMixin(PropsMixin):
return []
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def split_projection(
cls, projection: Sequence[str]
) -> Tuple[Collection[str], Collection[str]]:
"""Return include and exclude lists based on passed projection and class definition"""
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
@classmethod
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
@ -502,7 +513,7 @@ class GetMixin(PropsMixin):
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
query: Q,
parameters=None,
override_projection=None,
):
@ -525,7 +536,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if search_text:
@ -533,13 +546,14 @@ class GetMixin(PropsMixin):
if order_by:
# add ordering
qs = qs.order_by(*order_by)
if only:
if include:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
qs = qs.only(*include)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
@ -575,7 +589,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
query_sets = [cls.objects(query)]
if order_by:
@ -612,16 +628,15 @@ class GetMixin(PropsMixin):
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if only:
if include:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
else:
exclude = set(cls.get_exclude_fields())
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
query_sets = [qs.only(*include) for qs in query_sets]
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
@ -632,7 +647,8 @@ class GetMixin(PropsMixin):
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
)
if len(ret) >= page_size:
break
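A hedged sketch of the split that split_projection performs (field names are stand-ins; the real method additionally unions the class's get_exclude_fields() into the exclusion set and subtracts anything explicitly included):

    from boltons.iterutils import partition

    projection = ["name", "status", "-comment"]
    include, exclude = partition(projection, key=lambda x: x[0] != "-")
    # include == ["name", "status"], exclude == ["-comment"]
    exclude = {x.lstrip("-") for x in exclude}  # -> {"comment"}

The include list then feeds qs.only(...) and the exclusion set feeds qs.exclude(...), as in the updated query code above.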

View File

@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
)
dst[path_part] = [
copy_path(path_parts[depth + 1:], s, d)
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
@ -96,6 +96,7 @@ class _ProxyManager:
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
exclusion_prefix = "-"
@property
def doc_projection(self):
@ -128,20 +129,28 @@ class ProjectionHelper(object):
[]
) # Projection information for reference fields (used in join queries)
for field in projection:
field_ = field.lstrip(self.exclusion_prefix)
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
if not field_.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
if field_ == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
subfield = field_[len(ref_field) :]
if not subfield.startswith(SEP):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
ref_projection_info.append(
(
ref_field,
ref_field_cls,
("" if field_[0] == field[0] else self.exclusion_prefix)
+ subfield[1:],
)
)
break
else:
# Not a reference field, just add to the top-level projection
@ -149,7 +158,7 @@ class ProjectionHelper(object):
orig_field = field
if field.endswith(".*"):
field = field[:-2]
if not field:
if not field.lstrip(self.exclusion_prefix):
raise errors.bad_request.InvalidFields(
field=orig_field, object=doc_cls.__name__
)
@ -199,7 +208,7 @@ class ProjectionHelper(object):
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(
doc_cls, doc_projection.difference(ref_projection).union({"id"})
doc_cls, doc_projection.difference(ref_projection)
)
# Make sure that in case one or more fields are subfields of another field, we only use the top-level field.
@ -218,7 +227,10 @@ class ProjectionHelper(object):
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
f
for f in doc_projection
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
not in doc_cls.get_fields()
]
if invalid_fields:
raise errors.bad_request.InvalidFields(
@ -234,6 +246,13 @@ class ProjectionHelper(object):
doc_projection.add(field)
doc_projection = list(doc_projection)
# If there are include fields (not only exclude) then add an id field
if (
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
and "id" not in doc_projection
):
doc_projection.append("id")
self._doc_projection = doc_projection
self._ref_projection = ref_projection
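A standalone restatement of the id rule above (not the server's code, just the logic): an id field is forced in only when the projection actually includes something, so a purely exclusive projection is left untouched:

    def with_implicit_id(projection):
        projection = list(projection)
        if any(not p.startswith("-") for p in projection) and "id" not in projection:
            projection.append("id")
        return projection

    assert with_implicit_id(["name", "status"]) == ["name", "status", "id"]
    assert with_implicit_id(["-comment"]) == ["-comment"]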
@ -314,6 +333,7 @@ class ProjectionHelper(object):
]
if items:
def do_projection(item):
ref_field_name, data, ids = item

View File

@ -148,6 +148,11 @@
type: array
items { "$ref": "#/definitions/queue_entry" }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
@ -305,6 +310,11 @@
type: array
items { type: string }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
response {
@ -367,6 +377,11 @@
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
tags {
description: "New user tags for the worker"
type: array
items: { type: string }
}
}
}
response {

View File

@ -46,10 +46,10 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
def register(call: APICall, company_id, req_model: RegisterRequest):
worker = req_model.worker
timeout = req_model.timeout
queues = req_model.queues
def register(call: APICall, company_id, request: RegisterRequest):
worker = request.worker
timeout = request.timeout
queues = request.queues
if not timeout or timeout <= 0:
raise bad_request.WorkerRegistrationFailed(
@ -63,6 +63,7 @@ def register(call: APICall, company_id, req_model: RegisterRequest):
ip=call.real_ip,
queues=queues,
timeout=timeout,
tags=request.tags,
)
@ -78,6 +79,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
user_id=call.identity.user,
ip=call.real_ip,
report=request,
tags=request.tags,
)
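A hedged call sketch in the same client style as the tests in this commit (worker name, queue and tag values are placeholders; the non-tag report fields are omitted): tags can be supplied at registration and refreshed on any status report, and per the BLL code above a report without tags leaves the stored tags unchanged.

    self.api.workers.register(worker="worker-1", timeout=600, queues=["default"], tags=["gpu"])
    # later reports may update the tags; the rest of the report payload is elided here
    self.api.workers.status_report(worker="worker-1", tags=["gpu", "busy"])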

View File

@ -6,14 +6,86 @@ log = config.logger(__file__)
class TestProjection(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.6")
def _temp_task(self, **kwargs):
self.update_missing(
kwargs,
type="testing",
name="test projection",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs)
def _temp_project(self):
return self.create_temp(
"projects",
name="Test projection",
description="test",
delete_params=dict(force=True),
)
def test_overlapping_fields(self):
message = "task started"
task_id = self.create_temp(
"tasks", name="test", type="testing", input=dict(view=dict())
)
task_id = self._temp_task()
self.api.tasks.started(task=task_id, status_message=message)
task = self.api.tasks.get_all_ex(
id=[task_id], only_fields=["status", "status_message"]
).tasks[0]
assert task["status"] == TaskStatus.in_progress
assert task["status_message"] == message
def test_task_projection(self):
project = self._temp_project()
task1 = self._temp_task(project=project)
task2 = self._temp_task(project=project)
self.api.tasks.started(task=task2, status_message="Started")
res = self.api.tasks.get_all_ex(
project=[project],
only_fields=[
"system_tags",
"company",
"type",
"name",
"tags",
"status",
"project.name",
"user.name",
"started",
"last_update",
"last_iteration",
"comment",
],
order_by=["-started"],
page=0,
page_size=15,
system_tags=["-archived"],
type=[
"__$not",
"annotation_manual",
"__$not",
"annotation",
"__$not",
"dataset_import",
],
).tasks
self.assertEqual([task2, task1], [t.id for t in res])
self.assertEqual("Test projection", res[0].project.name)
def test_exclude_projection(self):
task_id = self._temp_task()
res = self.api.tasks.get_all_ex(
id=[task_id]
).tasks[0]
self.assertEqual("test projection", res.name)
task = self.api.tasks.get_all_ex(
id=[task_id],
only_fields=["-name"]
).tasks[0]
self.assertFalse("name" in task)
self.assertEqual("testing", res.type)

View File

@ -0,0 +1,281 @@
from operator import itemgetter
from typing import Sequence, List, Tuple
from boltons import iterutils
from apierrors.errors.bad_request import InvalidTaskStatus
from tests.api_client import APIClient
from tests.automated import TestService
class TestTasksHyperparams(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")
def new_task(self, **kwargs) -> Tuple[str, str]:
if "project" not in kwargs:
kwargs["project"] = self.create_temp(
"projects",
name="Test hyperparams",
description="test",
delete_params=dict(force=True),
)
self.update_missing(
kwargs,
type="testing",
name="test hyperparams",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs), kwargs["project"]
def test_hyperparams(self):
legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"}
new_params = [
dict(section="1/1", name="param1/1", type="type1", value="10"),
dict(section="1/1", name="param2", type="type1", value="20"),
dict(section="2", name="param2", type="type2", value="xxx"),
]
new_params_dict = self._param_dict_from_list(new_params)
task, project = self.new_task(
execution={"parameters": legacy_params}, hyperparams=new_params_dict,
)
# both params and hyper params are set correctly
old_params = self._new_params_from_legacy(legacy_params)
params_dict = new_params_dict.copy()
params_dict["Args"] = {p["name"]: p for p in old_params}
res = self.api.tasks.get_by_id(task=task).task
self.assertEqual(params_dict, res.hyperparams)
# returned as one list, with the legacy params mapped into the Args section
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params + old_params, res.hyperparams)
# replace section
replace_params = [
dict(section="1/1", name="param1", type="type1", value="40"),
dict(section="2", name="param5", type="type1", value="11"),
]
self.api.tasks.edit_hyper_params(
task=task, hyperparams=replace_params, replace_hyperparams="section"
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(replace_params + old_params, res.hyperparams)
# replace all
replace_params = [
dict(section="1/1", name="param1/1", type="type1", value="30"),
dict(section="Args", name="legacy$1", value="123", type="legacy"),
]
self.api.tasks.edit_hyper_params(
task=task, hyperparams=replace_params, replace_hyperparams="all"
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(replace_params, res.hyperparams)
# add and update
self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params + old_params, res.hyperparams)
# delete
new_to_delete = self._get_param_keys(new_params[1:])
old_to_delete = self._get_param_keys(old_params[:1])
self.api.tasks.delete_hyper_params(
task=task, hyperparams=new_to_delete + old_to_delete
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams)
# delete section
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}]
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(old_params[1:], res.hyperparams)
# project hyperparams
res = self.api.projects.get_hyper_parameters(project=project)
self.assertEqual(
[
{k: v for k, v in p.items() if k in ("section", "name")}
for p in old_params[1:]
],
res.parameters,
)
# clone task
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
finally:
self.api.tasks.delete(task=new_task, force=True)
# editing of started task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="properties", name="x", value="123")]
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="Properties")]
)
@staticmethod
def _get_param_keys(params: Sequence[dict]) -> List[dict]:
return [{k: p[k] for k in ("name", "section")} for p in params]
@staticmethod
def _new_params_from_legacy(legacy: dict) -> List[dict]:
return [
dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
for k, v in legacy.items()
]
@staticmethod
def _param_dict_from_list(params: Sequence[dict]) -> dict:
return {
k: {v["name"]: v for v in values}
for k, values in iterutils.bucketize(
params, key=itemgetter("section")
).items()
}
@staticmethod
def _config_dict_from_list(config: Sequence[dict]) -> dict:
return {c["name"]: c for c in config}
def test_configuration(self):
legacy_config = {"design": "hello"}
new_config = [
dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"),
]
new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task(
execution={"model_desc": legacy_config}, configuration=new_config_dict
)
# both the legacy design and the new configuration are set correctly
old_config = self._new_config_from_legacy(legacy_config)
config_dict = new_config_dict.copy()
config_dict["design"] = old_config[0]
res = self.api.tasks.get_by_id(task=task).task
self.assertEqual(config_dict, res.configuration)
# returned as one list
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config, res.configuration)
# names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names)
# returned as one list with names filtering
res = self.api.tasks.get_configurations(
tasks=[task], names=[new_config[1]["name"]]
).configurations[0]
self.assertEqual([new_config[1]], res.configuration)
# replace all
replace_configs = [
dict(name="design", value="123", type="legacy"),
dict(name="param/2", type="type1", value="30"),
]
self.api.tasks.edit_configuration(
task=task, configuration=replace_configs, replace_configuration=True
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(replace_configs, res.configuration)
# add and update
self.api.tasks.edit_configuration(
task=task, configuration=new_config + old_config
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config, res.configuration)
# delete
new_to_delete = self._get_config_keys(new_config[1:])
res = self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)
finally:
self.api.tasks.delete(task=new_task, force=True)
@staticmethod
def _get_config_keys(config: Sequence[dict]) -> List[dict]:
return [c["name"] for c in config]
@staticmethod
def _new_config_from_legacy(legacy: dict) -> List[dict]:
return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()]
def test_hyperparams_projection(self):
legacy_param = {"legacy.1": "val1"}
new_params1 = [
dict(section="sec.tion1", name="param1", type="type1", value="10")
]
new_params_dict1 = self._param_dict_from_list(new_params1)
task1, project = self.new_task(
execution={"parameters": legacy_param}, hyperparams=new_params_dict1,
)
new_params2 = [
dict(section="sec.tion1", name="param1", type="type1", value="20")
]
new_params_dict2 = self._param_dict_from_list(new_params2)
task2, _ = self.new_task(hyperparams=new_params_dict2, project=project)
old_params = self._new_params_from_legacy(legacy_param)
params_dict = new_params_dict1.copy()
params_dict["Args"] = {p["name"]: p for p in old_params}
res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[
0
]
self.assertEqual(params_dict, res.hyperparams)
res = self.api.tasks.get_all_ex(
project=[project],
only_fields=["hyperparams.sec%2Etion1"],
order_by=["-hyperparams.sec%2Etion1"],
).tasks[0]
self.assertEqual(new_params_dict2, res.hyperparams)
def test_old_api(self):
legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"}
legacy_config = {"design": "hello"}
task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config}
)
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8")
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(legacy_params, task.execution.parameters)
self.assertEqual(legacy_config, task.execution.model_desc)
self.assertEqual(params, task.hyperparams)
self.assertEqual(config, task.configuration)
modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc)