clearml-server/server/database/projection.py
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain

import dpath

from apierrors import errors
from database.props import PropsMixin

def project_dict(data, projection, separator='.'):
    """
    Project partial data from a dictionary into a new dictionary
    :param data: Input dictionary
    :param projection: List of dictionary paths (each a string with field names separated using a separator)
    :param separator: Separator (default is '.')
    :return: A new dictionary containing only the projected parts from the original dictionary
    """
    assert isinstance(data, dict)
    result = {}

    def copy_path(path_parts, source, destination):
        src, dst = source, destination
        try:
            for depth, path_part in enumerate(path_parts[:-1]):
                src_part = src[path_part]
                if isinstance(src_part, dict):
                    src = src_part
                    dst = dst.setdefault(path_part, {})
                elif isinstance(src_part, (list, tuple)):
                    if path_part not in dst:
                        dst[path_part] = [{} for _ in range(len(src_part))]
                    elif not isinstance(dst[path_part], (list, tuple)):
                        raise TypeError('Incompatible destination type %s for %s (list expected)'
                                        % (type(dst), separator.join(path_parts[:depth + 1])))
                    elif not len(dst[path_part]) == len(src_part):
                        raise ValueError('Destination list length differs from source length for %s'
                                         % separator.join(path_parts[:depth + 1]))

                    dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
                                      for s, d in zip(src_part, dst[path_part])]

                    return destination
                else:
                    raise TypeError('Unsupported projection type %s for %s'
                                    % (type(src), separator.join(path_parts[:depth + 1])))

            last_part = path_parts[-1]
            dst[last_part] = src[last_part]
        except KeyError:
            # Projection field not in source, no biggie.
            pass

        return destination

    for projection_path in sorted(projection):
        copy_path(
            path_parts=projection_path.split(separator),
            source=data,
            destination=result)

    return result
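
# A quick illustration (hypothetical data, not part of the original module) of how project_dict()
# keeps only the requested branches, including per-item projection over lists of sub-documents:
#
#     >>> project_dict({'user': {'name': 'ada', 'role': 'admin'}, 'id': 1}, ['user.name'])
#     {'user': {'name': 'ada'}}
#     >>> project_dict({'runs': [{'id': 1, 'loss': 0.5}, {'id': 2, 'loss': 0.7}]}, ['runs.id'])
#     {'runs': [{'id': 1}, {'id': 2}]}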


class ProjectionHelper(object):
    pool = ThreadPoolExecutor()

    @property
    def doc_projection(self):
        return self._doc_projection

    def __init__(self, doc_cls, projection, expand_reference_ids=False):
        super(ProjectionHelper, self).__init__()
        self._should_expand_reference_ids = expand_reference_ids
        self._doc_cls = doc_cls
        self._doc_projection = None
        self._ref_projection = None
        self._parse_projection(projection)
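
    # Typical usage sketch (hypothetical document class and query layer; the projection_func
    # callable is described in project() below):
    #
    #     helper = ProjectionHelper(doc_cls=Task, projection=['name', 'project.name'])
    #     only_fields = helper.doc_projection         # top-level fields for the main query
    #     results = run_main_query(only=only_fields)  # hypothetical: a list of result dictionaries
    #     helper.project(results, projection_func)    # joins/expands reference fields in-place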

    def _collect_projection_fields(self, doc_cls, projection):
        """
        Collect projection for the given document into an immediate document projection
        and a reference documents projection
        :param doc_cls: Document class
        :param projection: List of projection fields
        :return: A tuple of document projection and reference fields information
        """
        doc_projection = set()    # Projection fields for this class (used in the main query)
        ref_projection_info = []  # Projection information for reference fields (used in join queries)
        for field in projection:
            for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
                if not field.startswith(ref_field):
                    # Doesn't start with a reference field
                    continue
                if field == ref_field:
                    # Field is exactly a reference field. In this case we won't perform any inner projection
                    # (for that, use '<reference field name>.*')
                    continue
                subfield = field[len(ref_field):]
                if not subfield.startswith('.'):
                    # Starts with something that looks like a reference field, but isn't
                    continue
                ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
                break
            else:
                # Not a reference field, just add to the top-level projection.
                # We strip any trailing '*' since it means nothing for simple fields and for embedded documents
                orig_field = field
                if field.endswith('.*'):
                    field = field[:-2]
                    if not field:
                        raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
                doc_projection.add(field)

        return doc_projection, ref_projection_info
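
    # For illustration (assuming a hypothetical Task document whose 'project' field references a
    # Project document), projection=['name', 'project', 'project.name'] is collected as:
    #   doc_projection      -> {'name', 'project'}
    #   ref_projection_info -> [('project', Project, 'name')]
    # i.e. 'project.name' is routed to a join sub-query while the rest stays in the main query.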

    def _parse_projection(self, projection):
        """
        Prepare the projection data structures for get_many_with_join().
        :param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (i.e. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
            should be returned (only relevant for fields that represent sub-documents or referenced documents)
        :type projection: list of strings
        :return: A tuple of (class fields projection, reference fields projection)
        """
        doc_cls = self._doc_cls
        assert issubclass(doc_cls, PropsMixin)
        if not projection:
            return [], {}

        doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)

        def normalize_cls_projection(cls_, fields):
            """ Normalize projection for this class and group (expand '*', for one) """
            if '*' in fields:
                return list(fields.difference('*').union(cls_.get_fields()))
            return list(fields)

        def compute_ref_cls_projection(cls_, group):
            """ Compute inner projection for this class and group """
            subfields = set([x[2] for x in group if x[2]])
            return normalize_cls_projection(cls_, subfields)

        def sort_key(proj_info):
            return proj_info[:2]

        # Aggregate by reference field. '*' is left out of the projected items since it is expanded
        # into all of the reference class's fields by normalize_cls_projection()
        ref_projection = {
            ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
            for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
        }

        # Make sure this doesn't contain any reference field we'll join anyway
        # (i.e. in case only_fields=[project, project.name])
        doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))

        # Make sure that in case one or more fields is a sub-field of another field, we only use the top-level field.
        # This is done since in such a case, MongoDB will only use the most restrictive field (the most nested field)
        # and won't return some of the data we need.
        # This way, we make sure to use the most inclusive field that contains all requested sub-fields.
        projection_set = set(doc_projection)
        doc_projection = [
            field
            for field in doc_projection
            if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
        ]

        # Make sure we didn't get any invalid projection fields for this class
        invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
        if invalid_fields:
            raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)

        if ref_projection:
            # Join mode - use both normal projection fields and top-level reference fields
            doc_projection = set(doc_projection)
            for field in set(ref_projection).difference(doc_projection):
                if any(f for f in doc_projection if field.startswith(f)):
                    continue
                doc_projection.add(field)
            doc_projection = list(doc_projection)

        self._doc_projection = doc_projection
        self._ref_projection = ref_projection
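
    # As a further sketch (same hypothetical Task/Project documents as above), parsing
    # projection=['name', 'project.name'] leaves the helper holding something like
    # (list order is not significant):
    #   self._doc_projection -> ['id', 'name', 'project']   (fields fetched by the main query)
    #   self._ref_projection -> {'project': {'cls': Project, 'only': ['name']}}   (join sub-queries)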

    @staticmethod
    def _search(doc_cls, obj, path, only_values=True):
        """ Call dpath.search with yielded=True, collect result values """
        norm_path = doc_cls.get_dpath_translated_path(path)
        return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
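
    # Sketch (hypothetical document class and data, assuming get_dpath_translated_path() leaves a
    # simple path unchanged): _search(Task, {'execution': {'model': 'abc123'}}, 'execution.model')
    # returns ['abc123']; with only_values=False it returns [('execution.model', 'abc123')] - the
    # '.' separator keeps the yielded paths in the dotted form that project() later feeds back
    # into dpath.new().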

    def project(self, results, projection_func):
        """
        Perform projection on query results, using the provided projection func.
        :param results: A list of result dictionaries on which projection should be performed
        :param projection_func: A callable that receives a document type, a list of ids and a projection and returns
            query results. This callable is used in order to perform sub-queries during projection
        :return: Modified results (in-place)
        """
        cls = self._doc_cls
        ref_projection = self._ref_projection

        if ref_projection:
            # Join mode - get results for each reference field projection required (this is the join step)
            # Note: this is a recursive step, so we support nested reference fields
            def do_projection(item):
                ref_field_name, data = item
                res = {}
                ids = list(filter(None, set(chain.from_iterable(self._search(cls, r, ref_field_name)
                                                                for r in results))))
                if ids:
                    doc_type = data['cls']
                    doc_only = list(filter(None, data['only']))
                    doc_only = list({'id'} | set(doc_only)) if doc_only else None
                    res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
                data['res'] = res

            items = list(ref_projection.items())
            if len(ref_projection) == 1:
                do_projection(items[0])
            else:
                for _ in self.pool.map(do_projection, items):
                    # From ThreadPoolExecutor.map() documentation: if a call raises an exception then that exception
                    # will be raised when its value is retrieved from the map() iterator
                    pass

        def do_expand_reference_ids(result, skip_fields=None):
            ref_fields = cls.get_reference_fields()
            if skip_fields:
                ref_fields = set(ref_fields) - set(skip_fields)
            self._expand_reference_fields(cls, result, ref_fields)

        def merge_projection_result(result):
            for ref_field_name, data in ref_projection.items():
                res = data.get('res')
                if not res:
                    self._expand_reference_fields(cls, result, [ref_field_name])
                    continue
                ref_ids = self._search(cls, result, ref_field_name, only_values=False)
                if not ref_ids:
                    continue
                for path, value in ref_ids:
                    obj = res.get(value) or {'id': value}
                    dpath.new(result, path, obj, separator='.')
            # Any reference field not projected should be expanded
            do_expand_reference_ids(result, skip_fields=list(ref_projection))

        update_func = merge_projection_result if ref_projection else \
            do_expand_reference_ids if self._should_expand_reference_ids else None

        if update_func:
            for result in results:
                update_func(result)

        return results
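
    # A minimal sketch of a compatible projection_func, assuming mongoengine-style documents
    # (the names below are illustrative; the real callable is supplied by the calling query code):
    #
    #     def projection_func(doc_type, projection, ids):
    #         query = doc_type.objects(id__in=ids)      # mongoengine-style filter, for illustration
    #         if projection:
    #             query = query.only(*projection)
    #         return [to_result_dict(doc) for doc in query]  # hypothetical serializer; dicts must carry 'id'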

    @classmethod
    def _expand_reference_fields(cls, doc_cls, result, fields):
        for ref_field_name in fields:
            ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
            if not ref_ids:
                continue
            for path, value in ref_ids:
                dpath.set(
                    result,
                    path,
                    {'id': value} if value else {},
                    separator='.')

    @classmethod
    def expand_reference_ids(cls, doc_cls, result):
        cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
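
    # For illustration (hypothetical data): expanding {'id': 't1', 'project': 'p1'} for a document
    # class whose 'project' field is a reference rewrites it in place to
    # {'id': 't1', 'project': {'id': 'p1'}}, so reference fields are always returned as objects.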