mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add support for field exclusion in get_all endpoints
Add support for ephemeral worker tags (valid while worker has not timed out)
This commit is contained in:
parent
8c7e230898
commit
cd4ce30f7c
@ -13,11 +13,5 @@ class GetHyperParamReq(ProjectReq):
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
@ -1,7 +1,9 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apimodels import DictField, ListField
|
||||
from apimodels.base import UpdateResponse
|
||||
@ -103,6 +105,8 @@ class CloneRequest(TaskRequest):
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
@ -118,3 +122,72 @@ class AddOrUpdateArtifactsResponse(models.Base):
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@ -74,6 +75,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
|
@ -50,6 +50,7 @@ class WorkerBLL:
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@ -59,6 +60,7 @@ class WorkerBLL:
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@ -92,6 +94,7 @@ class WorkerBLL:
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
@ -114,12 +117,15 @@ class WorkerBLL:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@ -130,6 +136,9 @@ class WorkerBLL:
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
|
@ -1,9 +1,9 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first, bucketize
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
@ -348,6 +348,17 @@ class GetMixin(PropsMixin):
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
@ -502,7 +513,7 @@ class GetMixin(PropsMixin):
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query,
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
@ -525,7 +536,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
@ -533,13 +546,14 @@ class GetMixin(PropsMixin):
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
if only:
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
@ -575,7 +589,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
@ -612,16 +628,15 @@ class GetMixin(PropsMixin):
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if only:
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@ -632,7 +647,8 @@ class GetMixin(PropsMixin):
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
|
@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
@ -96,6 +96,7 @@ class _ProxyManager:
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@ -128,20 +129,28 @@ class ProjectionHelper(object):
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
@ -149,7 +158,7 @@ class ProjectionHelper(object):
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
@ -199,7 +208,7 @@ class ProjectionHelper(object):
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
@ -218,7 +227,10 @@ class ProjectionHelper(object):
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
@ -234,6 +246,13 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@ -314,6 +333,7 @@ class ProjectionHelper(object):
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
|
@ -148,6 +148,11 @@
|
||||
type: array
|
||||
items { "$ref": "#/definitions/queue_entry" }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -305,6 +310,11 @@
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@ -367,6 +377,11 @@
|
||||
description: "The machine statistics."
|
||||
"$ref": "#/definitions/machine_stats"
|
||||
}
|
||||
tags {
|
||||
description: "New user tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
@ -46,10 +46,10 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
|
||||
|
||||
|
||||
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
|
||||
def register(call: APICall, company_id, req_model: RegisterRequest):
|
||||
worker = req_model.worker
|
||||
timeout = req_model.timeout
|
||||
queues = req_model.queues
|
||||
def register(call: APICall, company_id, request: RegisterRequest):
|
||||
worker = request.worker
|
||||
timeout = request.timeout
|
||||
queues = request.queues
|
||||
|
||||
if not timeout or timeout <= 0:
|
||||
raise bad_request.WorkerRegistrationFailed(
|
||||
@ -63,6 +63,7 @@ def register(call: APICall, company_id, req_model: RegisterRequest):
|
||||
ip=call.real_ip,
|
||||
queues=queues,
|
||||
timeout=timeout,
|
||||
tags=request.tags,
|
||||
)
|
||||
|
||||
|
||||
@ -78,6 +79,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
|
||||
user_id=call.identity.user,
|
||||
ip=call.real_ip,
|
||||
report=request,
|
||||
tags=request.tags,
|
||||
)
|
||||
|
||||
|
||||
|
@ -6,14 +6,86 @@ log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestProjection(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.6")
|
||||
|
||||
def _temp_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test projection",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
||||
|
||||
def _temp_project(self):
|
||||
return self.create_temp(
|
||||
"projects",
|
||||
name="Test projection",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
|
||||
def test_overlapping_fields(self):
|
||||
message = "task started"
|
||||
task_id = self.create_temp(
|
||||
"tasks", name="test", type="testing", input=dict(view=dict())
|
||||
)
|
||||
task_id = self._temp_task()
|
||||
self.api.tasks.started(task=task_id, status_message=message)
|
||||
task = self.api.tasks.get_all_ex(
|
||||
id=[task_id], only_fields=["status", "status_message"]
|
||||
).tasks[0]
|
||||
assert task["status"] == TaskStatus.in_progress
|
||||
assert task["status_message"] == message
|
||||
|
||||
def test_task_projection(self):
|
||||
project = self._temp_project()
|
||||
task1 = self._temp_task(project=project)
|
||||
task2 = self._temp_task(project=project)
|
||||
self.api.tasks.started(task=task2, status_message="Started")
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
project=[project],
|
||||
only_fields=[
|
||||
"system_tags",
|
||||
"company",
|
||||
"type",
|
||||
"name",
|
||||
"tags",
|
||||
"status",
|
||||
"project.name",
|
||||
"user.name",
|
||||
"started",
|
||||
"last_update",
|
||||
"last_iteration",
|
||||
"comment",
|
||||
],
|
||||
order_by=["-started"],
|
||||
page=0,
|
||||
page_size=15,
|
||||
system_tags=["-archived"],
|
||||
type=[
|
||||
"__$not",
|
||||
"annotation_manual",
|
||||
"__$not",
|
||||
"annotation",
|
||||
"__$not",
|
||||
"dataset_import",
|
||||
],
|
||||
).tasks
|
||||
self.assertEqual([task2, task1], [t.id for t in res])
|
||||
self.assertEqual("Test projection", res[0].project.name)
|
||||
|
||||
def test_exclude_projection(self):
|
||||
task_id = self._temp_task()
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=[task_id]
|
||||
).tasks[0]
|
||||
self.assertEqual("test projection", res.name)
|
||||
|
||||
task = self.api.tasks.get_all_ex(
|
||||
id=[task_id],
|
||||
only_fields=["-name"]
|
||||
).tasks[0]
|
||||
self.assertFalse("name" in task)
|
||||
self.assertEqual("testing", res.type)
|
||||
|
281
server/tests/automated/test_task_hyperparams.py
Normal file
281
server/tests/automated/test_task_hyperparams.py
Normal file
@ -0,0 +1,281 @@
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, List, Tuple
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors.errors.bad_request import InvalidTaskStatus
|
||||
from tests.api_client import APIClient
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestTasksHyperparams(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.9")
|
||||
|
||||
def new_task(self, **kwargs) -> Tuple[str, str]:
|
||||
if "project" not in kwargs:
|
||||
kwargs["project"] = self.create_temp(
|
||||
"projects",
|
||||
name="Test hyperparams",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test hyperparams",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs), kwargs["project"]
|
||||
|
||||
def test_hyperparams(self):
|
||||
legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"}
|
||||
new_params = [
|
||||
dict(section="1/1", name="param1/1", type="type1", value="10"),
|
||||
dict(section="1/1", name="param2", type="type1", value="20"),
|
||||
dict(section="2", name="param2", type="type2", value="xxx"),
|
||||
]
|
||||
new_params_dict = self._param_dict_from_list(new_params)
|
||||
task, project = self.new_task(
|
||||
execution={"parameters": legacy_params}, hyperparams=new_params_dict,
|
||||
)
|
||||
# both params and hyper params are set correctly
|
||||
old_params = self._new_params_from_legacy(legacy_params)
|
||||
params_dict = new_params_dict.copy()
|
||||
params_dict["Args"] = {p["name"]: p for p in old_params}
|
||||
res = self.api.tasks.get_by_id(task=task).task
|
||||
self.assertEqual(params_dict, res.hyperparams)
|
||||
|
||||
# returned as one list with params in the _legacy section
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params + old_params, res.hyperparams)
|
||||
|
||||
# replace section
|
||||
replace_params = [
|
||||
dict(section="1/1", name="param1", type="type1", value="40"),
|
||||
dict(section="2", name="param5", type="type1", value="11"),
|
||||
]
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=replace_params, replace_hyperparams="section"
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(replace_params + old_params, res.hyperparams)
|
||||
|
||||
# replace all
|
||||
replace_params = [
|
||||
dict(section="1/1", name="param1/1", type="type1", value="30"),
|
||||
dict(section="Args", name="legacy$1", value="123", type="legacy"),
|
||||
]
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=replace_params, replace_hyperparams="all"
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(replace_params, res.hyperparams)
|
||||
|
||||
# add and update
|
||||
self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params + old_params, res.hyperparams)
|
||||
|
||||
# delete
|
||||
new_to_delete = self._get_param_keys(new_params[1:])
|
||||
old_to_delete = self._get_param_keys(old_params[:1])
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=new_to_delete + old_to_delete
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams)
|
||||
|
||||
# delete section
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}]
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(old_params[1:], res.hyperparams)
|
||||
|
||||
# project hyperparams
|
||||
res = self.api.projects.get_hyper_parameters(project=project)
|
||||
self.assertEqual(
|
||||
[
|
||||
{k: v for k, v in p.items() if k in ("section", "name")}
|
||||
for p in old_params[1:]
|
||||
],
|
||||
res.parameters,
|
||||
)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
|
||||
self.assertEqual(new_params, res.hyperparams)
|
||||
finally:
|
||||
self.api.tasks.delete(task=new_task, force=True)
|
||||
|
||||
# editing of started task
|
||||
self.api.tasks.started(task=task)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=[dict(section="test", name="x", value="123")]
|
||||
)
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=[dict(section="properties", name="x", value="123")]
|
||||
)
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=[dict(section="Properties")]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_param_keys(params: Sequence[dict]) -> List[dict]:
|
||||
return [{k: p[k] for k in ("name", "section")} for p in params]
|
||||
|
||||
@staticmethod
|
||||
def _new_params_from_legacy(legacy: dict) -> List[dict]:
|
||||
return [
|
||||
dict(section="Args", name=k, value=str(v), type="legacy")
|
||||
if not k.startswith("TF_DEFINE/")
|
||||
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
|
||||
for k, v in legacy.items()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _param_dict_from_list(params: Sequence[dict]) -> dict:
|
||||
return {
|
||||
k: {v["name"]: v for v in values}
|
||||
for k, values in iterutils.bucketize(
|
||||
params, key=itemgetter("section")
|
||||
).items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _config_dict_from_list(config: Sequence[dict]) -> dict:
|
||||
return {c["name"]: c for c in config}
|
||||
|
||||
def test_configuration(self):
|
||||
legacy_config = {"design": "hello"}
|
||||
new_config = [
|
||||
dict(name="param$1", type="type1", value="10"),
|
||||
dict(name="param/2", type="type1", value="20"),
|
||||
]
|
||||
new_config_dict = self._config_dict_from_list(new_config)
|
||||
task, _ = self.new_task(
|
||||
execution={"model_desc": legacy_config}, configuration=new_config_dict
|
||||
)
|
||||
|
||||
# both params and hyper params are set correctly
|
||||
old_config = self._new_config_from_legacy(legacy_config)
|
||||
config_dict = new_config_dict.copy()
|
||||
config_dict["design"] = old_config[0]
|
||||
res = self.api.tasks.get_by_id(task=task).task
|
||||
self.assertEqual(config_dict, res.configuration)
|
||||
|
||||
# returned as one list
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config, res.configuration)
|
||||
|
||||
# names
|
||||
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
|
||||
self.assertEqual(task, res.task)
|
||||
self.assertEqual(["design", "param$1", "param/2"], res.names)
|
||||
|
||||
# returned as one list with names filtering
|
||||
res = self.api.tasks.get_configurations(
|
||||
tasks=[task], names=[new_config[1]["name"]]
|
||||
).configurations[0]
|
||||
self.assertEqual([new_config[1]], res.configuration)
|
||||
|
||||
# replace all
|
||||
replace_configs = [
|
||||
dict(name="design", value="123", type="legacy"),
|
||||
dict(name="param/2", type="type1", value="30"),
|
||||
]
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=replace_configs, replace_configuration=True
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(replace_configs, res.configuration)
|
||||
|
||||
# add and update
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=new_config + old_config
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config, res.configuration)
|
||||
|
||||
# delete
|
||||
new_to_delete = self._get_config_keys(new_config[1:])
|
||||
res = self.api.tasks.delete_configuration(
|
||||
task=task, configuration=new_to_delete
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config[:1], res.configuration)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
|
||||
self.assertEqual(new_config, res.configuration)
|
||||
finally:
|
||||
self.api.tasks.delete(task=new_task, force=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_config_keys(config: Sequence[dict]) -> List[dict]:
|
||||
return [c["name"] for c in config]
|
||||
|
||||
@staticmethod
|
||||
def _new_config_from_legacy(legacy: dict) -> List[dict]:
|
||||
return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()]
|
||||
|
||||
def test_hyperparams_projection(self):
|
||||
legacy_param = {"legacy.1": "val1"}
|
||||
new_params1 = [
|
||||
dict(section="sec.tion1", name="param1", type="type1", value="10")
|
||||
]
|
||||
new_params_dict1 = self._param_dict_from_list(new_params1)
|
||||
task1, project = self.new_task(
|
||||
execution={"parameters": legacy_param}, hyperparams=new_params_dict1,
|
||||
)
|
||||
|
||||
new_params2 = [
|
||||
dict(section="sec.tion1", name="param1", type="type1", value="20")
|
||||
]
|
||||
new_params_dict2 = self._param_dict_from_list(new_params2)
|
||||
task2, _ = self.new_task(hyperparams=new_params_dict2, project=project)
|
||||
|
||||
old_params = self._new_params_from_legacy(legacy_param)
|
||||
params_dict = new_params_dict1.copy()
|
||||
params_dict["Args"] = {p["name"]: p for p in old_params}
|
||||
res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[
|
||||
0
|
||||
]
|
||||
self.assertEqual(params_dict, res.hyperparams)
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
project=[project],
|
||||
only_fields=["hyperparams.sec%2Etion1"],
|
||||
order_by=["-hyperparams.sec%2Etion1"],
|
||||
).tasks[0]
|
||||
self.assertEqual(new_params_dict2, res.hyperparams)
|
||||
|
||||
def test_old_api(self):
|
||||
legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"}
|
||||
legacy_config = {"design": "hello"}
|
||||
task_id, _ = self.new_task(
|
||||
execution={"parameters": legacy_params, "model_desc": legacy_config}
|
||||
)
|
||||
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
|
||||
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
|
||||
|
||||
old_api = APIClient(base_url="http://localhost:8008/v2.8")
|
||||
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(legacy_params, task.execution.parameters)
|
||||
self.assertEqual(legacy_config, task.execution.model_desc)
|
||||
self.assertEqual(params, task.hyperparams)
|
||||
self.assertEqual(config, task.configuration)
|
||||
|
||||
modified_params = {"legacy.2": "val2"}
|
||||
modified_config = {"design": "by"}
|
||||
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
|
||||
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(modified_params, task.execution.parameters)
|
||||
self.assertEqual(modified_config, task.execution.model_desc)
|
Loading…
Reference in New Issue
Block a user