import threading
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable

from boltons import iterutils

from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.props import PropsMixin

SEP = "."
max_items_per_fetch = config.get("services._mongo.max_page_size", 500)


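# A referenced document is initially represented inside each result by a proxy dict holding only the
# reference id; once the joined sub-query returns, _ProxyManager updates the proxy in place with the
# projected fields, so all results sharing that reference see the fetched data.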
class _ReferenceProxy(dict):
    def __init__(self, id):
        super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))


class _ProxyManager:
    lock = threading.Lock()

    def __init__(self):
        self._proxies: Dict[str, _ReferenceProxy] = {}

    def add(self, id):
        with self.lock:
            proxy = self._proxies.get(id)
            if proxy is None:
                proxy = self._proxies[id] = _ReferenceProxy(id)
            return proxy

    def update(self, result):
        proxy = self._proxies.get(result.get("id"))
        if proxy is not None:
            proxy.update(result)


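# Typical flow (illustrative; "Task", "project" and projection_func are assumed names, not defined here):
#   helper = ProjectionHelper(doc_cls=Task, projection=["name", "project.name"])
#   ... run the main query using helper.doc_projection ...
#   results = helper.project(results, projection_func)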
class ProjectionHelper(object):
    pool = ThreadPoolExecutor()
    exclusion_prefix = "-"

    @property
    def doc_projection(self):
        return self._doc_projection

    def __init__(self, doc_cls, projection, expand_reference_ids=False):
        super(ProjectionHelper, self).__init__()
        self._should_expand_reference_ids = expand_reference_ids
        self._doc_cls = doc_cls
        self._doc_projection = None
        self._ref_projection = None
        self._proxy_manager = _ProxyManager()

        self._parse_projection(projection)

    def _collect_projection_fields(self, doc_cls, projection):
        """
        Collect projection for the given document into immediate document projection
        and reference documents projection
        :param doc_cls: Document class
        :param projection: List of projection fields
        :return: A tuple of document projection and reference fields information
        """
        doc_projection = (
            set()
        )  # Projection fields for this class (used in the main query)
        ref_projection_info = (
            []
        )  # Projection information for reference fields (used in join queries)
        for field in projection:
            field_ = field.lstrip(self.exclusion_prefix)
            for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
                if not field_.startswith(ref_field):
                    # Doesn't start with a reference field
                    continue
                if field_ == ref_field:
                    # Field is exactly a reference field. In this case we won't perform any inner projection
                    # (for that, use '<reference field name>.*')
                    continue
                subfield = field_[len(ref_field):]
                if not subfield.startswith(SEP):
                    # Starts with something that looks like a reference field, but isn't
                    continue

                ref_projection_info.append(
                    (
                        ref_field,
                        ref_field_cls,
                        ("" if field_[0] == field[0] else self.exclusion_prefix)
                        + subfield[1:],
                    )
                )
                break
            else:
                # Not a reference field, just add to the top-level projection.
                # We strip any trailing '*' since it means nothing for simple fields and for embedded documents
                orig_field = field
                if field.endswith(".*"):
                    field = field[:-2]
                if not field.lstrip(self.exclusion_prefix):
                    raise errors.bad_request.InvalidFields(
                        field=orig_field, object=doc_cls.__name__
                    )
                doc_projection.add(field)
        return doc_projection, ref_projection_info

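    # Illustrative example (assumed "Task"/"Project" names, not from the original code): for a document
    # class whose reference fields include "project", the projection ["name", "project.name", "project.*"]
    # is split into doc_projection == {"name"} and
    # ref_projection_info == [("project", Project, "name"), ("project", Project, "*")].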
    def _parse_projection(self, projection):
        """
        Prepare the projection data structures for get_many_with_join().
        :param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (e.g. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
            should be returned (only relevant for fields that represent sub-documents or referenced documents)
        :type projection: list of strings
        :returns: A tuple of (class fields projection, reference fields projection)
        """
        doc_cls = self._doc_cls
        assert issubclass(doc_cls, PropsMixin)
        if not projection:
            return [], {}

        doc_projection, ref_projection_info = self._collect_projection_fields(
            doc_cls, projection
        )

        def normalize_cls_projection(cls_, fields):
            """ Normalize projection for this class and group (expand *, for once) """
            if "*" in fields:
                return list(fields.difference("*").union(cls_.get_fields()))
            return list(fields)

        def compute_ref_cls_projection(cls_, group):
            """ Compute inner projection for this class and group """
            subfields = set([x[2] for x in group if x[2]])
            return normalize_cls_projection(cls_, subfields)

        def sort_key(proj_info):
            return proj_info[:2]

        # Aggregate by reference field. We'll leave out '*' from the projected items since
        # compute_ref_cls_projection() already expands it into the reference class' full field list
        ref_projection = {
            ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
            for (ref_field, ref_cls), g in groupby(
                sorted(ref_projection_info, key=sort_key), sort_key
            )
        }

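        # Illustrative result (same assumed names as above): ["project.name", "project.id"] is grouped into
        # {"project": {"cls": Project, "only": ["name", "id"]}} (the "only" list order may vary).
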
        # Make sure this doesn't contain any reference field we'll join anyway
        # (i.e. in case only_fields=[project, project.name])
        doc_projection = normalize_cls_projection(
            doc_cls, doc_projection.difference(ref_projection)
        )

        # Make sure that in case one or more field is a subfield of another field, we only use the top-level field.
        # This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
        # won't return some of the data we need.
        # This way, we make sure to use the most inclusive field that contains all requested subfields.
        projection_set = set(doc_projection)
        doc_projection = [
            field
            for field in doc_projection
            if not any(
                field.startswith(f"{other_field}.")
                for other_field in projection_set - {field}
            )
        ]

        # Make sure we didn't get any invalid projection fields for this class
        invalid_fields = [
            f
            for f in doc_projection
            if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
            not in doc_cls.get_fields()
        ]
        if invalid_fields:
            raise errors.bad_request.InvalidFields(
                fields=invalid_fields, object=doc_cls.__name__
            )

        if ref_projection:
            # Join mode - use both normal projection fields and top-level reference fields
            doc_projection = set(doc_projection)
            for field in set(ref_projection).difference(doc_projection):
                if any(f for f in doc_projection if field.startswith(f)):
                    continue
                doc_projection.add(field)
            doc_projection = list(doc_projection)

        # If there are include fields (not only exclude fields) then add an "id" field
        if (
            not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
            and "id" not in doc_projection
        ):
            doc_projection.append("id")

        self._doc_projection = doc_projection
        self._ref_projection = ref_projection

    def _search(
        self,
        doc_cls: PropsMixin,
        obj: dict,
        path: str,
        factory: Callable[[str], dict] = None,
    ) -> Sequence[str]:
        """
        Search for a path in the given object, return the list of values found for the
        given path (multiple values may exist if the path is a glob expression)
        :param doc_cls: The document class represented by the object
        :param obj: Data object
        :param path: Path to a leaf in the data object ("." separated, may contain "*")
            (in case the path contains "*", there may be multiple values)
        :param factory: If provided, replace each value found with an instance provided by the factory.
        """
        norm_path = doc_cls.get_dpath_translated_path(path)
        globlist = norm_path.strip(SEP).split(SEP)

        def _search_and_replace(target: dict, p: Sequence[str]) -> Sequence[str]:
            parent = None
            for idx, part in enumerate(p):
                if isinstance(target, dict) and part in target:
                    parent = target
                    target = target[part]
                elif isinstance(target, list) and part == "*":
                    return list(
                        chain.from_iterable(
                            _search_and_replace(t, p[idx + 1 :]) for t in target
                        )
                    )
                else:
                    return []

            if parent and factory:
                parent[p[-1]] = factory(target)
            return [target]

        return _search_and_replace(obj, globlist)

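    # Illustrative behavior (assumed data, and assuming the translated path equals the given path):
    # for obj == {"models": {"input": [{"model": "id1"}, {"model": "id2"}]}} and path "models.input.*.model",
    # _search() returns ["id1", "id2"]; when a factory is given, each found value is also replaced in place
    # with factory(value).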
    def project(self, results, projection_func):
        """
        Perform projection on query results, using the provided projection func.
        :param results: A list of result dictionaries on which projection should be performed
        :param projection_func: A callable that receives a document type, a list of ids and a projection, and returns
            query results. This callable is used in order to perform sub-queries during projection
        :return: Modified results (in-place)
        """
        cls = self._doc_cls
        ref_projection = self._ref_projection

        if ref_projection:
            # Join mode - get results for each reference field projection required (this is the join step)
            # Note: this is a recursive step, so nested reference fields are supported

            def collect_ids(ref_field_name):
                """
                Collect unique IDs for the given reference path from all result documents.
                All collected IDs are replaced in the result dictionaries with a reference proxy generated by the
                proxies manager to allow rapid update later on when projection results are obtained.
                """
                all_ids = (
                    self._search(
                        cls, res, ref_field_name, factory=self._proxy_manager.add
                    )
                    for res in results
                )
                return list(filter(None, set(chain.from_iterable(all_ids))))

            items = [
                tup
                for tup in (
                    (*item, collect_ids(item[0])) for item in ref_projection.items()
                )
                if tup[2]
            ]

            if items:

                def do_projection(item):
                    ref_field_name, data, ids = item

                    doc_type = data["cls"]
                    doc_only = list(filter(None, data["only"]))
                    doc_only = list({"id"} | set(doc_only)) if doc_only else None

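                    # Fetch the referenced documents in chunks of up to max_items_per_fetch ids and feed each
                    # returned document into the proxy manager, so that every result dict holding a proxy for
                    # that id is updated in place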
                    for ids_chunk in iterutils.chunked_iter(ids, max_items_per_fetch):
                        for res in projection_func(
                            doc_type=doc_type, projection=doc_only, ids=ids_chunk
                        ):
                            self._proxy_manager.update(res)

                if len(ref_projection) == 1:
                    do_projection(items[0])
                else:
                    for _ in self.pool.map(do_projection, items):
                        # From ThreadPoolExecutor.map() documentation: If a call raises an exception then that
                        # exception will be raised when its value is retrieved from the map() iterator
                        pass

        def do_expand_reference_ids(result, skip_fields=None):
            ref_fields = cls.get_reference_fields()
            if skip_fields:
                ref_fields = set(ref_fields) - set(skip_fields)
            self._expand_reference_fields(cls, result, ref_fields)

        # Any reference field not projected should be expanded
        if self._should_expand_reference_ids:
            for result in results:
                do_expand_reference_ids(
                    result, skip_fields=list(ref_projection) if ref_projection else None
                )

        return results

    def _expand_reference_fields(self, doc_cls, result, fields):
        for ref_field_name in fields:
            self._search(doc_cls, result, ref_field_name, factory=_ReferenceProxy)

    def expand_reference_ids(self, doc_cls, result):
        self._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())