Fix numeric hyperparam values are not sorted lexicographically with descending sort order

This commit is contained in:
allegroai 2023-05-25 19:15:59 +03:00
parent 4291ad682a
commit 5c5d9b6434
2 changed files with 20 additions and 1 deletions

View File

@ -754,7 +754,9 @@ class GetMixin(PropsMixin):
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
return first(
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
v
for k, v in cls._field_collation_overrides.items()
if field.startswith(k) or field.startswith(f"-{k}")
)
@classmethod

View File

@ -298,6 +298,23 @@ class TestTasksHyperparams(TestService):
).tasks[0]
self.assertEqual(new_params_dict2, res.hyperparams)
def test_numeric_ordering(self):
params = [
dict(section="section1", name="param1", type="type1", value="1"),
dict(section="section1", name="param1", type="type1", value="2"),
dict(section="section1", name="param1", type="type1", value="11"),
]
tasks = [
self.new_task(hyperparams=self._param_dict_from_list([p]), project=None)[0]
for p in params
]
res = self.api.tasks.get_all_ex(id=tasks, order_by=["hyperparams.section1.param1"]).tasks
self.assertEqual([t.id for t in res], tasks)
res = self.api.tasks.get_all_ex(id=tasks, order_by=["-hyperparams.section1.param1"]).tasks
self.assertEqual([t.id for t in res], list(reversed(tasks)))
def test_old_api(self):
legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"}
legacy_config = {"design": "hello"}