Fix numeric locale

This commit is contained in:
allegroai 2021-05-03 18:04:45 +03:00
parent b99f620073
commit 1a3d3494ce
2 changed files with 56 additions and 18 deletions

View File

@ -78,7 +78,6 @@ class GetMixin(PropsMixin):
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
_numeric_collation = {"locale": "en_US", "numericOrdering": True}
class QueryParameterOptions(object):
def __init__(
@ -211,6 +210,28 @@ class GetMixin(PropsMixin):
pairs = ((field, parameters.pop(field, None)) for field in fields)
return {k: v for k, v in pairs if v is not None}
@classmethod
def _try_convert_to_numeric(cls, value: Union[str, Sequence[str]]):
def convert_str(val: str) -> Union[float, str]:
try:
return float(val)
except ValueError:
return val
if isinstance(value, str):
return convert_str(value)
if isinstance(value, (list, tuple)) and all(isinstance(v, str) for v in value):
return [convert_str(v) for v in value]
return value
@classmethod
def _get_fixed_field_value(cls, field: str, value):
if field.startswith("last_metrics."):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@ -233,7 +254,9 @@ class GetMixin(PropsMixin):
dict_query = {}
query = RegexQ()
if parameters:
parameters = parameters.copy()
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
@ -509,6 +532,12 @@ class GetMixin(PropsMixin):
return helper.project(results, projection_func)
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
    """
    Look up a collation override for a field by key prefix.

    Walks ``cls._field_collation_overrides`` and hands ``first`` the values
    whose key is a prefix of ``field``.

    :param field: Field name to match against the override keys.
    :return: The value picked by ``first``.
        NOTE(review): ``first`` is imported outside this block; presumably it
        yields ``None`` when nothing matches, as the ``Optional`` return
        annotation suggests - confirm against its documentation.
    """
    overrides = cls._field_collation_overrides
    return first(
        collation
        for prefix, collation in overrides.items()
        if field.startswith(prefix)
    )
@classmethod
def get_many(
cls,
@ -546,6 +575,13 @@ class GetMixin(PropsMixin):
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:return: A list of objects matching the query.
"""
override_collation = None
if query_dict:
for field in query_dict:
override_collation = cls._get_collation_override(field)
if override_collation:
break
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
@ -562,10 +598,14 @@ class GetMixin(PropsMixin):
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls._get_many_no_company(
query=_query, parameters=parameters, override_projection=override_projection
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
@classmethod
@ -589,6 +629,7 @@ class GetMixin(PropsMixin):
query: Q,
parameters=None,
override_projection=None,
override_collation=None,
):
"""
Fetch all documents matching a provided query.
@ -608,12 +649,16 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
page, page_size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if override_collation:
qs = qs.collation(collation=override_collation)
if search_text:
qs = qs.search_text(search_text)
if order_by:
@ -668,6 +713,7 @@ class GetMixin(PropsMixin):
query: Q = None,
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@ -704,21 +750,13 @@ class GetMixin(PropsMixin):
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
# if order_field:
# collation_override = first(
# v
# for k, v in cls._field_collation_overrides.items()
# if order_field.startswith(k)
# )
# if collation_override:
# query_sets = [
# qs.collation(collation=collation_override) for qs in query_sets
# ]
if order_field and not override_collation:
override_collation = cls._get_collation_override(order_field)
# always use numeric collation
query_sets = [
qs.collation(collation=cls._numeric_collation) for qs in query_sets
]
if override_collation:
query_sets = [
qs.collation(collation=override_collation) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]

View File

@ -163,7 +163,6 @@ class Task(AttributedDocument):
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
@ -184,6 +183,7 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{"fields": ["company", "project"], "collation": _numeric_locale},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [