From 901ec372909c7fcc1aba858c5caafec268bd4017 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 6 Jul 2020 22:05:36 +0300 Subject: [PATCH] Improve pre-populate on server startup (including sync lock) --- docker-compose.yml | 6 +- server/bll/event/event_bll.py | 2 +- server/bll/task/task_bll.py | 2 - server/config/default/apiserver.conf | 13 +- server/database/model/task/task.py | 1 + server/mongo/initialize/__init__.py | 64 ++- server/mongo/initialize/pre_populate.py | 499 ++++++++++++++++++++---- server/server.py | 16 +- server/services/projects.py | 3 +- server/services/tasks.py | 3 +- server/sync.py | 28 ++ 11 files changed, 525 insertions(+), 112 deletions(-) create mode 100644 server/sync.py diff --git a/docker-compose.yml b/docker-compose.yml index 822c9e6..1c77092 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,7 @@ services: volumes: - /opt/trains/logs:/var/log/trains - /opt/trains/config:/opt/trains/config + - /opt/trains/data/fileserver:/mnt/fileserver depends_on: - redis - mongo @@ -23,8 +24,9 @@ services: TRAINS_REDIS_SERVICE_HOST: redis TRAINS_REDIS_SERVICE_PORT: 6379 TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux} - TRAINS__apiserver__mongo__pre_populate__enabled: "true" - TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip" + TRAINS__apiserver__pre_populate__enabled: "true" + TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate" + TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver" ports: - "8008:8008" networks: diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py index 52ff6e6..428aff7 100644 --- a/server/bll/event/event_bll.py +++ b/server/bll/event/event_bll.py @@ -552,7 +552,7 @@ class EventBLL(object): ) events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])] - next_scroll_id = es_res["_scroll_id"] + next_scroll_id = es_res.get("_scroll_id") total_events = es_res["hits"]["total"] return TaskEventsResult( diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index 6f69ccf..ef80fa5 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -32,7 +32,6 @@ from database.model.task.task import ( ) from database.utils import get_company_or_none_constraint, id as create_id from service_repo import APICall -from services.utils import validate_tags from timing_context import TimingContext from utilities.dicts import deep_merge from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper @@ -182,7 +181,6 @@ class TaskBLL(object): execution_overrides: Optional[dict] = None, validate_references: bool = False, ) -> Task: - validate_tags(tags, system_tags) task = cls.get_by_id(company_id=company_id, task_id=task_id) execution_dict = task.execution.to_proper_dict() if task.execution else {} execution_model_overriden = False diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf index c39041d..d1cd078 100644 --- a/server/config/default/apiserver.conf +++ b/server/config/default/apiserver.conf @@ -26,6 +26,13 @@ check_max_version: false } + pre_populate { + enabled: false + zip_files: ["/path/to/export.zip"] + fail_on_error: false + artifacts_path: "/mnt/fileserver" + } + mongo { # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data # but not declared in a data model @@ -34,12 +41,6 @@ aggregate { allow_disk_use: true } - - pre_populate { - enabled: false - zip_file: 
"/path/to/export.zip" - fail_on_error: false - } } auth { diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index 8601c8e..2ee57f9 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -85,6 +85,7 @@ class Artifact(EmbeddedDocument): class Execution(EmbeddedDocument, ProperDictMixin): + meta = {"strict": strict} test_split = IntField(default=0) parameters = SafeDictField(default=dict) model = StringField(reference_field="Model") diff --git a/server/mongo/initialize/__init__.py b/server/mongo/initialize/__init__.py index e506aaf..d5bd9ec 100644 --- a/server/mongo/initialize/__init__.py +++ b/server/mongo/initialize/__init__.py @@ -1,6 +1,8 @@ from pathlib import Path +from typing import Sequence, Union from config import config +from config.info import get_default_company from database.model.auth import Role from service_repo.auth.fixed_user import FixedUser from .migration import _apply_migrations @@ -11,7 +13,48 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid log = config.logger(__package__) -def init_mongo_data(): +def _pre_populate(company_id: str, zip_file: str): + if not zip_file or not Path(zip_file).is_file(): + msg = f"Invalid pre-populate zip file: {zip_file}" + if config.get("apiserver.pre_populate.fail_on_error", False): + log.error(msg) + raise ValueError(msg) + else: + log.warning(msg) + else: + log.info(f"Pre-populating using {zip_file}") + + user_id = _ensure_backend_user( + "__allegroai__", company_id, "Allegro.ai" + ) + + PrePopulate.import_from_zip( + zip_file, + company_id="", + user_id=user_id, + artifacts_path=config.get( + "apiserver.pre_populate.artifacts_path", None + ), + ) + + +def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]: + if isinstance(zip_files, str): + zip_files = [zip_files] + for p in map(Path, zip_files): + if p.is_file(): + yield p + if p.is_dir(): + yield from p.glob("*.zip") + log.warning(f"Invalid pre-populate entry {str(p)}, skipping") + + +def pre_populate_data(): + for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")): + _pre_populate(company_id=get_default_company(), zip_file=zip_file) + + +def init_mongo_data() -> bool: try: empty_dbs = _apply_migrations(log) @@ -21,23 +64,6 @@ def init_mongo_data(): _ensure_default_queue(company_id) - if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False): - zip_file = config.get("apiserver.mongo.pre_populate.zip_file") - if not zip_file or not Path(zip_file).is_file(): - msg = f"Failed pre-populating database: invalid zip file {zip_file}" - if config.get("apiserver.mongo.pre_populate.fail_on_error", False): - log.error(msg) - raise ValueError(msg) - else: - log.warning(msg) - else: - - user_id = _ensure_backend_user( - "__allegroai__", company_id, "Allegro.ai" - ) - - PrePopulate.import_from_zip(zip_file, user_id=user_id) - fixed_mode = FixedUser.enabled() for user, credentials in config.get("secure.credentials", {}).items(): @@ -61,5 +87,7 @@ def init_mongo_data(): ensure_fixed_user(user, company_id, log=log) except Exception as ex: log.error(f"Failed creating fixed user {user.name}: {ex}") + + return empty_dbs except Exception as ex: log.exception("Failed initializing mongodb") diff --git a/server/mongo/initialize/pre_populate.py b/server/mongo/initialize/pre_populate.py index 3035d62..e69ed44 100644 --- a/server/mongo/initialize/pre_populate.py +++ b/server/mongo/initialize/pre_populate.py @@ -1,31 +1,199 @@ +import hashlib import 
importlib +import os from collections import defaultdict -from datetime import datetime +from datetime import datetime, timezone +from io import BytesIO +from itertools import chain +from operator import attrgetter from os.path import splitext -from typing import List, Optional, Any, Type, Set, Dict +from pathlib import Path +from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union +from urllib.parse import unquote, urlparse from zipfile import ZipFile, ZIP_BZIP2 +import attr import mongoengine -from tqdm import tqdm +from boltons.iterutils import chunked_iter +from furl import furl +from mongoengine import Q + +from bll.event import EventBLL +from database.model import EntityVisibility +from database.model.model import Model +from database.model.project import Project +from database.model.task.task import Task, ArtifactModes, TaskStatus +from database.utils import get_options +from utilities import json class PrePopulate: - @classmethod - def export_to_zip( - cls, filename: str, experiments: List[str] = None, projects: List[str] = None - ): - with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile: - cls._export(zfile, experiments, projects) + event_bll = EventBLL() + events_file_suffix = "_events" + export_tag_prefix = "Exported:" + export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S" + + class JsonLinesWriter: + def __init__(self, file: BinaryIO): + self.file = file + self.empty = True + + def __enter__(self): + self._write("[") + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self._write("\n]") + + def _write(self, data: str): + self.file.write(data.encode("utf-8")) + + def write(self, line: str): + if not self.empty: + self._write(",") + self._write("\n" + line) + self.empty = False + + @attr.s(auto_attribs=True) + class _MapData: + files: Sequence[str] = None + entities: Dict[str, datetime] = None + + @staticmethod + def _get_last_update_time(entity) -> datetime: + return getattr(entity, "last_update", None) or getattr(entity, "created") @classmethod - def import_from_zip(cls, filename: str, user_id: str = None): + def _check_for_update( + cls, map_file: Path, entities: dict + ) -> Tuple[bool, Sequence[str]]: + if not map_file.is_file(): + return True, [] + + files = [] + try: + map_data = cls._MapData(**json.loads(map_file.read_text())) + files = map_data.files + for file in files: + if not Path(file).is_file(): + return True, files + + new_times = { + item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc) + for item in chain.from_iterable(entities.values()) + } + old_times = map_data.entities + + if set(new_times.keys()) != set(old_times.keys()): + return True, files + + for id_, new_timestamp in new_times.items(): + if new_timestamp != old_times[id_]: + return True, files + except Exception as ex: + print("Error reading map file. 
" + str(ex)) + return True, files + + return False, files + + @classmethod + def _write_update_file( + cls, map_file: Path, entities: dict, created_files: Sequence[str] + ): + map_data = cls._MapData( + files=created_files, + entities={ + entity.id: cls._get_last_update_time(entity) + for entity in chain.from_iterable(entities.values()) + }, + ) + map_file.write_text(json.dumps(attr.asdict(map_data))) + + @staticmethod + def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]: + def is_fileserver_link(a: str) -> bool: + a = a.lower() + if a.startswith("https://files."): + return True + if a.startswith("http"): + parsed = urlparse(a) + if parsed.scheme in {"http", "https"} and parsed.port == 8081: + return True + return False + + fileserver_links = [a for a in artifacts if is_fileserver_link(a)] + print( + f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total" + ) + + return fileserver_links + + @classmethod + def export_to_zip( + cls, + filename: str, + experiments: Sequence[str] = None, + projects: Sequence[str] = None, + artifacts_path: str = None, + task_statuses: Sequence[str] = None, + tag_exported_entities: bool = False, + ) -> Sequence[str]: + if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)): + raise ValueError("Invalid task statuses") + + file = Path(filename) + entities = cls._resolve_entities( + experiments=experiments, projects=projects, task_statuses=task_statuses + ) + + map_file = file.with_suffix(".map") + updated, old_files = cls._check_for_update(map_file, entities) + if not updated: + print(f"There are no updates from the last export") + return old_files + for old in old_files: + old_path = Path(old) + if old_path.is_file(): + old_path.unlink() + + zip_args = dict(mode="w", compression=ZIP_BZIP2) + with ZipFile(file, **zip_args) as zfile: + artifacts, hash_ = cls._export( + zfile, entities, tag_entities=tag_exported_entities + ) + file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}") + file.replace(file_with_hash) + created_files = [str(file_with_hash)] + + artifacts = cls._filter_artifacts(artifacts) + if artifacts and artifacts_path and os.path.isdir(artifacts_path): + artifacts_file = file_with_hash.with_suffix(".artifacts") + with ZipFile(artifacts_file, **zip_args) as zfile: + cls._export_artifacts(zfile, artifacts, artifacts_path) + created_files.append(str(artifacts_file)) + + cls._write_update_file(map_file, entities, created_files) + + return created_files + + @classmethod + def import_from_zip( + cls, filename: str, company_id: str, user_id: str, artifacts_path: str + ): with ZipFile(filename) as zfile: - cls._import(zfile, user_id) + cls._import(zfile, company_id, user_id) + + if artifacts_path and os.path.isdir(artifacts_path): + artifacts_file = Path(filename).with_suffix(".artifacts") + if artifacts_file.is_file(): + print(f"Unzipping artifacts into {artifacts_path}") + with ZipFile(artifacts_file) as zfile: + zfile.extractall(artifacts_path) @staticmethod def _resolve_type( - cls: Type[mongoengine.Document], ids: Optional[List[str]] - ) -> List[Any]: + cls: Type[mongoengine.Document], ids: Optional[Sequence[str]] + ) -> Sequence[Any]: ids = set(ids) items = list(cls.objects(id__in=list(ids))) resolved = {i.id for i in items} @@ -43,20 +211,24 @@ class PrePopulate: @classmethod def _resolve_entities( - cls, experiments: List[str] = None, projects: List[str] = None + cls, + experiments: Sequence[str] = None, + projects: Sequence[str] = None, + task_statuses: Sequence[str] = 
None, ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]: - from database.model.project import Project - from database.model.task.task import Task - entities = defaultdict(set) if projects: print("Reading projects...") entities[Project].update(cls._resolve_type(Project, projects)) print("--> Reading project experiments...") - objs = Task.objects( - project__in=list(set(filter(None, (p.id for p in entities[Project])))) + query = Q( + project__in=list(set(filter(None, (p.id for p in entities[Project])))), + system_tags__nin=[EntityVisibility.archived.value], ) + if task_statuses: + query &= Q(status__in=list(set(task_statuses))) + objs = Task.objects(query) entities[Task].update(o for o in objs if o.id not in (experiments or [])) if experiments: @@ -69,85 +241,256 @@ class PrePopulate: project_ids = {p.id for p in entities[Project]} entities[Project].update(o for o in objs if o.id not in project_ids) + model_ids = { + model_id + for task in entities[Task] + for model_id in (task.output.model, task.execution.model) + if model_id + } + if model_ids: + print("Reading models...") + entities[Model] = set(Model.objects(id__in=list(model_ids))) + return entities @classmethod - def _cleanup_task(cls, task): - from database.model.task.task import TaskStatus + def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]: + if not tags: + return tags + return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)] - task.completed = None - task.started = None - if task.execution: - task.execution.model = None - task.execution.model_desc = None - task.execution.model_labels = None - if task.output: - task.output.model = None + @classmethod + def _cleanup_model(cls, model: Model): + model.company = "" + model.user = "" + model.tags = cls._filter_out_export_tags(model.tags) - task.status = TaskStatus.created + @classmethod + def _cleanup_task(cls, task: Task): task.comment = "Auto generated by Allegro.ai" - task.created = datetime.utcnow() - task.last_iteration = 0 - task.last_update = task.created - task.status_changed = task.created task.status_message = "" task.status_reason = "" task.user = "" + task.company = "" + task.tags = cls._filter_out_export_tags(task.tags) + if task.output: + task.output.destination = None + + @classmethod + def _cleanup_project(cls, project: Project): + project.user = "" + project.company = "" + project.tags = cls._filter_out_export_tags(project.tags) @classmethod def _cleanup_entity(cls, entity_cls, entity): - from database.model.task.task import Task if entity_cls == Task: cls._cleanup_task(entity) + elif entity_cls == Model: + cls._cleanup_model(entity) + elif entity == Project: + cls._cleanup_project(entity) + + @classmethod + def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str): + try: + for item in items: + item.update(upsert=False, tags=sorted(item.tags + [tag])) + except AttributeError: + pass + + @classmethod + def _export_task_events( + cls, task: Task, base_filename: str, writer: ZipFile, hash_ + ) -> Sequence[str]: + artifacts = [] + filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json" + print(f"Writing task events into {writer.filename}:{filename}") + with BytesIO() as f: + with cls.JsonLinesWriter(f) as w: + scroll_id = None + while True: + res = cls.event_bll.get_task_events( + task.company, task.id, scroll_id=scroll_id + ) + if not res.events: + break + scroll_id = res.next_scroll_id + for event in res.events: + if event.get("type") == "training_debug_image": + url = 
cls._get_fixed_url(event.get("url")) + if url: + event["url"] = url + artifacts.append(url) + w.write(json.dumps(event)) + data = f.getvalue() + hash_.update(data) + writer.writestr(filename, data) + + return artifacts + + @staticmethod + def _get_fixed_url(url: Optional[str]) -> Optional[str]: + if not (url and url.lower().startswith("s3://")): + return url + try: + fixed = furl(url) + fixed.scheme = "https" + fixed.host += ".s3.amazonaws.com" + return fixed.url + except Exception as ex: + print(f"Failed processing link {url}. " + str(ex)) + return url + + @classmethod + def _export_entity_related_data( + cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_ + ): + if entity_cls == Task: + return [ + *cls._get_task_output_artifacts(entity), + *cls._export_task_events(entity, base_filename, writer, hash_), + ] + + if entity_cls == Model: + entity.uri = cls._get_fixed_url(entity.uri) + return [entity.uri] if entity.uri else [] + + return [] + + @classmethod + def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]: + if not task.execution.artifacts: + return [] + + for a in task.execution.artifacts: + if a.mode == ArtifactModes.output: + a.uri = cls._get_fixed_url(a.uri) + + return [ + a.uri + for a in task.execution.artifacts + if a.mode == ArtifactModes.output and a.uri + ] + + @classmethod + def _export_artifacts( + cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str + ): + unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts) + print(f"Writing {len(unique_paths)} artifacts into {writer.filename}") + for path in unique_paths: + path = path.lstrip("/") + full_path = os.path.join(artifacts_path, path) + if os.path.isfile(full_path): + writer.write(full_path, path) + else: + print(f"Artifact {full_path} not found") @classmethod def _export( - cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None - ): - entities = cls._resolve_entities(experiments, projects) - - for cls_, items in entities.items(): + cls, writer: ZipFile, entities: dict, tag_entities: bool = False + ) -> Tuple[Sequence[str], str]: + """ + Export the requested experiments, projects and models and return the list of artifact files + Always do the export on sorted items since the order of items influence hash + """ + artifacts = [] + now = datetime.utcnow() + hash_ = hashlib.md5() + for cls_ in sorted(entities, key=attrgetter("__name__")): + items = sorted(entities[cls_], key=attrgetter("id")) if not items: continue - filename = f"{cls_.__module__}.{cls_.__name__}.json" + base_filename = f"{cls_.__module__}.{cls_.__name__}" + for item in items: + artifacts.extend( + cls._export_entity_related_data( + cls_, item, base_filename, writer, hash_ + ) + ) + filename = base_filename + ".json" print(f"Writing {len(items)} items into {writer.filename}:{filename}") - with writer.open(filename, "w") as f: - f.write("[\n".encode("utf-8")) - last = len(items) - 1 - for i, item in enumerate(items): - cls._cleanup_entity(cls_, item) - f.write(item.to_json().encode("utf-8")) - if i != last: - f.write(",".encode("utf-8")) - f.write("\n".encode("utf-8")) - f.write("]\n".encode("utf-8")) + with BytesIO() as f: + with cls.JsonLinesWriter(f) as w: + for item in items: + cls._cleanup_entity(cls_, item) + w.write(item.to_json()) + data = f.getvalue() + hash_.update(data) + writer.writestr(filename, data) + + if tag_entities: + cls._add_tag(items, now.strftime(cls.export_tag)) + + return artifacts, hash_.hexdigest() @staticmethod - def _import(reader: ZipFile, 
user_id: str = None): - for file_info in reader.filelist: - full_name = splitext(file_info.orig_filename)[0] - print(f"Reading {reader.filename}:{full_name}...") - module_name, _, class_name = full_name.rpartition(".") - module = importlib.import_module(module_name) - cls_: Type[mongoengine.Document] = getattr(module, class_name) + def json_lines(file: BinaryIO): + for line in file: + clean = ( + line.decode("utf-8") + .rstrip("\r\n") + .strip() + .lstrip("[") + .rstrip(",]") + .strip() + ) + if not clean: + continue + yield clean - with reader.open(file_info) as f: - for item in tqdm( - f.readlines(), - desc=f"Writing {cls_.__name__.lower()}s into database", - unit="doc", - ): - item = ( - item.decode("utf-8") - .strip() - .lstrip("[") - .rstrip("]") - .rstrip(",") - .strip() - ) - if not item: - continue - doc = cls_.from_json(item) - if user_id is not None and hasattr(doc, "user"): - doc.user = user_id - doc.save(force_insert=True) + @classmethod + def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None): + """ + Import entities and events from the zip file + Start from entities since event import will require the tasks already in DB + """ + event_file_ending = cls.events_file_suffix + ".json" + entity_files = ( + fi + for fi in reader.filelist + if not fi.orig_filename.endswith(event_file_ending) + ) + event_files = ( + fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) + ) + for files, reader_func in ( + (entity_files, cls._import_entity), + (event_files, cls._import_events), + ): + for file_info in files: + with reader.open(file_info) as f: + full_name = splitext(file_info.orig_filename)[0] + print(f"Reading {reader.filename}:{full_name}...") + reader_func(f, full_name, company_id, user_id) + + @classmethod + def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str): + module_name, _, class_name = full_name.rpartition(".") + module = importlib.import_module(module_name) + cls_: Type[mongoengine.Document] = getattr(module, class_name) + print(f"Writing {cls_.__name__.lower()}s into database") + for item in cls.json_lines(f): + doc = cls_.from_json(item, created=True) + if hasattr(doc, "user"): + doc.user = user_id + if hasattr(doc, "company"): + doc.company = company_id + if isinstance(doc, Project): + cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update( + set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}" + ) + doc.save() + if isinstance(doc, Task): + cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) + + @classmethod + def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _): + _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_") + print(f"Writing events for task {task_id} into database") + for events_chunk in chunked_iter(cls.json_lines(f), 1000): + events = [json.loads(item) for item in events_chunk] + cls.event_bll.add_events( + company_id, events=events, worker="", allow_locked_tasks=True + ) diff --git a/server/server.py b/server/server.py index fa648bd..ec847f9 100644 --- a/server/server.py +++ b/server/server.py @@ -1,5 +1,6 @@ import atexit from argparse import ArgumentParser +from hashlib import md5 from flask import Flask, request, Response from flask_compress import Compress @@ -11,10 +12,11 @@ from apierrors.base import BaseError from bll.statistics.stats_reporter import StatisticsReporter from config import config from elastic.initialize import init_es_data -from mongo.initialize import 
init_mongo_data +from mongo.initialize import init_mongo_data, pre_populate_data from service_repo import ServiceRepo, APICall from service_repo.auth import AuthType from service_repo.errors import PathParsingError +from sync import distributed_lock from timing_context import TimingContext from updates import check_updates_thread from utilities import json @@ -33,8 +35,16 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json") database.initialize() -init_es_data() -init_mongo_data() +# build a key that uniquely identifies specific mongo instance +hosts_string = ";".join(sorted(database.get_hosts())) +key = "db_init_" + md5(hosts_string.encode()).hexdigest() +with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)): + print(key) + init_es_data() + empty_db = init_mongo_data() +if empty_db and config.get("apiserver.pre_populate.enabled", False): + pre_populate_data() + ServiceRepo.load("services") log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}") diff --git a/server/services/projects.py b/server/services/projects.py index 0becbc8..e72a282 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None def get_all_ex(call: APICall): include_stats = call.data.get("include_stats") stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value) + allow_public = not call.data.get("non_public", False) if stats_for_state: try: @@ -200,7 +201,7 @@ def get_all_ex(call: APICall): company=call.identity.company, query_dict=call.data, query_options=get_all_query_options, - allow_public=True, + allow_public=allow_public, ) conform_output_tags(call, projects) diff --git a/server/services/tasks.py b/server/services/tasks.py index f4d15dc..79d15bb 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -56,7 +56,7 @@ from database.model.task.task import ( ) from database.utils import get_fields, parse_from_call from service_repo import APICall, endpoint -from services.utils import conform_tag_fields, conform_output_tags +from services.utils import conform_tag_fields, conform_output_tags, validate_tags from timing_context import TimingContext from utilities import safe_get @@ -377,6 +377,7 @@ def create(call: APICall, company_id, req_model: CreateRequest): "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse ) def clone_task(call: APICall, company_id, request: CloneRequest): + validate_tags(request.new_task_tags, request.new_task_system_tags) task = task_bll.clone_task( company_id=company_id, user_id=call.identity.user, diff --git a/server/sync.py b/server/sync.py new file mode 100644 index 0000000..a3e2fc8 --- /dev/null +++ b/server/sync.py @@ -0,0 +1,28 @@ +import time +from contextlib import contextmanager +from time import sleep + +from redis_manager import redman + +_redis = redman.connection("apiserver") + + +@contextmanager +def distributed_lock(name: str, timeout: int, max_wait: int = 0): + """ + Context manager that acquires a distributed lock on enter + and releases it on exit. 
The lock has a TTL equal to timeout seconds. + If the lock cannot be acquired within max_wait seconds (defaults to timeout * 2), + an exception is raised. + """ + lock_name = f"dist_lock_{name}" + start = time.time() + max_wait = max_wait or timeout * 2 + while not _redis.set(lock_name, value="", ex=timeout, nx=True): + sleep(1) + if time.time() - start > max_wait: + raise Exception(f"Could not acquire {name} lock for {max_wait} seconds") + try: + yield + finally: + _redis.delete(lock_name)
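
For reference, the startup sequence that these changes add up to (the Redis-backed lock from server/sync.py, init_mongo_data() now reporting whether the databases started out empty, and the new apiserver.pre_populate.* settings) is sketched below. This is an illustrative restatement of the server.py hunk above, not additional code in the patch; it assumes it runs inside the apiserver package so the imports resolve, and it omits logging and error handling.

    # Startup sketch: serialize DB initialization across apiserver instances that
    # share the same MongoDB, then pre-populate only a brand-new deployment.
    from hashlib import md5

    import database
    from config import config
    from elastic.initialize import init_es_data
    from mongo.initialize import init_mongo_data, pre_populate_data
    from sync import distributed_lock

    database.initialize()

    # The lock key is derived from the sorted mongo hosts, so only instances that
    # point at the same database contend for the same lock.
    hosts_string = ";".join(sorted(database.get_hosts()))
    key = "db_init_" + md5(hosts_string.encode()).hexdigest()

    with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
        init_es_data()
        empty_db = init_mongo_data()  # True when migrations ran on empty databases

    # Pre-population is opt-in and only happens on a fresh deployment;
    # apiserver.pre_populate.zip_files may list zip files or directories of *.zip.
    if empty_db and config.get("apiserver.pre_populate.enabled", False):
        pre_populate_data()

The lock itself is a plain Redis SET NX with a TTL of timeout seconds: contenders poll once per second and give up after max_wait seconds (2 * timeout by default), so a crashed initializer cannot hold the lock past its TTL.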