Sort hyper parameters numeric values as numbers and not strings

This commit is contained in:
allegroai 2020-06-01 12:27:56 +03:00
parent dcdf2a3d58
commit bf7f0f646b
4 changed files with 35 additions and 5 deletions

View File

@ -76,6 +76,8 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
class QueryParameterOptions(object):
def __init__(
self,
@ -476,6 +478,8 @@ class GetMixin(PropsMixin):
"""
Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order.
If the first order field is a user defined parameter (either from execution.parameters,
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
@ -516,6 +520,16 @@ class GetMixin(PropsMixin):
query_sets = [cls.objects(non_empty), cls.objects(empty)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
@ -672,5 +686,5 @@ def validate_id(cls, company, **kwargs):
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
)

View File

@ -103,6 +103,11 @@ class TaskType(object):
class Task(AttributedDocument):
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
}
meta = {
"db_alias": Database.backend,
"strict": strict,

View File

@ -475,7 +475,11 @@ get_all {
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
description: """List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned documents.
Use '-' prefix to specify descending order. Optional, recommended when using page.
If the first order field is a hyper parameter or metric then string values are ordered
according to numeric ordering rules where applicable"""
type: array
items { type: string }
}

View File

@ -61,7 +61,7 @@ class TestEntityOrdering(TestService):
page_size=page_size,
).tasks
def _assertSorted(self, vals: Sequence, ascending=True):
def _assertSorted(self, vals: Sequence, ascending=True, is_numeric=False):
"""
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
@ -80,6 +80,9 @@ class TestEntityOrdering(TestService):
self.assertTrue(all(val == empty_value for val in none_tail))
self.assertTrue(all(val != empty_value for val in vals))
if is_numeric:
vals = list(map(int, vals))
if ascending:
cmp = operator.le
else:
@ -106,14 +109,18 @@ class TestEntityOrdering(TestService):
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
)
def _create_tasks(self):
tasks = [
self._temp_task(
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
)
for i in range(10)
for i in range(20)
]
for idx, task in zip(range(5), tasks):
self.api.tasks.started(task=task)