Sort hyper parameters numeric values as numbers and not strings

This commit is contained in:
allegroai 2020-06-01 12:27:56 +03:00
parent dcdf2a3d58
commit bf7f0f646b
4 changed files with 35 additions and 5 deletions

View File

@ -76,6 +76,8 @@ class GetMixin(PropsMixin):
} }
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields") MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
class QueryParameterOptions(object): class QueryParameterOptions(object):
def __init__( def __init__(
self, self,
@ -476,6 +478,8 @@ class GetMixin(PropsMixin):
""" """
Fetch all documents matching a provided query. For the first order by field Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order. the None values are sorted in the end regardless of the sorting order.
If the first order field is a user defined parameter (either from execution.parameters,
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
This is a company-less version for internal uses. We assume the caller has either added any necessary This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required. constraints to the query or that no constraints are required.
@ -516,6 +520,16 @@ class GetMixin(PropsMixin):
query_sets = [cls.objects(non_empty), cls.objects(empty)] query_sets = [cls.objects(non_empty), cls.objects(empty)]
query_sets = [qs.order_by(*order_by) for qs in query_sets] query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if search_text: if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets] query_sets = [qs.search_text(search_text) for qs in query_sets]
@ -672,5 +686,5 @@ def validate_id(cls, company, **kwargs):
id_to_name.setdefault(obj_id, []).append(name) id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError( raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()), "Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]} **{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
) )

View File

@ -103,6 +103,11 @@ class TaskType(object):
class Task(AttributedDocument): class Task(AttributedDocument):
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
}
meta = { meta = {
"db_alias": Database.backend, "db_alias": Database.backend,
"strict": strict, "strict": strict,

View File

@ -475,7 +475,11 @@ get_all {
minimum: 1 minimum: 1
} }
order_by { order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page" description: """List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned documents.
Use '-' prefix to specify descending order. Optional, recommended when using page.
If the first order field is a hyper parameter or metric then string values are ordered
according to numeric ordering rules where applicable"""
type: array type: array
items { type: string } items { type: string }
} }

View File

@ -61,7 +61,7 @@ class TestEntityOrdering(TestService):
page_size=page_size, page_size=page_size,
).tasks ).tasks
def _assertSorted(self, vals: Sequence, ascending=True): def _assertSorted(self, vals: Sequence, ascending=True, is_numeric=False):
""" """
Assert that vals are sorted in the ascending or descending order Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end with None values are always coming from the end
@ -80,6 +80,9 @@ class TestEntityOrdering(TestService):
self.assertTrue(all(val == empty_value for val in none_tail)) self.assertTrue(all(val == empty_value for val in none_tail))
self.assertTrue(all(val != empty_value for val in vals)) self.assertTrue(all(val != empty_value for val in vals))
if is_numeric:
vals = list(map(int, vals))
if ascending: if ascending:
cmp = operator.le cmp = operator.le
else: else:
@ -106,14 +109,18 @@ class TestEntityOrdering(TestService):
# test that the output is correctly ordered # test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:] field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks] field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-")) self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
)
def _create_tasks(self): def _create_tasks(self):
tasks = [ tasks = [
self._temp_task( self._temp_task(
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}})) **(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
) )
for i in range(10) for i in range(20)
] ]
for idx, task in zip(range(5), tasks): for idx, task in zip(range(5), tasks):
self.api.tasks.started(task=task) self.api.tasks.started(task=task)