From 7816b402bb1662045c2022574ff1655090dbc9e0 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 10 Aug 2020 08:53:41 +0300
Subject: [PATCH] Enhance ES7 initialization and migration support

Support older task hyper-parameter migration on pre-population
---
 server/config/default/apiserver.conf          |  13 +-
 server/config/default/hosts.conf              |   4 +-
 server/config/info.py                         |   1 +
 server/elastic/apply_mappings.py              |  55 ++++----
 server/elastic/initialize.py                  |  62 ++++++---
 .../elastic/mappings/{ => events}/events.json |   0
 .../mappings/{ => events}/events_log.json     |   0
 .../mappings/{ => events}/events_plot.json    |   0
 .../events_training_debug_image.json          |   0
 .../mappings/{ => workers}/queue_metrics.json |   0
 .../mappings/{ => workers}/worker_stats.json  |   0
 server/es_factory.py                          |   4 +
 server/mongo/initialize/__init__.py           |  15 +--
 server/mongo/initialize/migration.py          |  39 +++---
 server/mongo/initialize/pre_populate.py       | 118 ++++++++++++++++--
 server/schema/services/auth.conf              |   6 +-
 server/server.py                              |  48 +++++--
 server/services/auth.py                       |  15 ++-
 18 files changed, 282 insertions(+), 98 deletions(-)
 rename server/elastic/mappings/{ => events}/events.json (100%)
 rename server/elastic/mappings/{ => events}/events_log.json (100%)
 rename server/elastic/mappings/{ => events}/events_plot.json (100%)
 rename server/elastic/mappings/{ => events}/events_training_debug_image.json (100%)
 rename server/elastic/mappings/{ => workers}/queue_metrics.json (100%)
 rename server/elastic/mappings/{ => workers}/worker_stats.json (100%)

diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf
index 198c1fa..0cb9c56 100644
--- a/server/config/default/apiserver.conf
+++ b/server/config/default/apiserver.conf
@@ -35,7 +35,7 @@

     # time in seconds to take an exclusive lock to init es and mongodb
     # not including the pre_populate
-    db_init_timout: 30
+    db_init_timout: 120

     mongo {
         # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
@@ -47,6 +47,17 @@
         }
     }

+    elastic {
+        probing {
+            # settings for initial probing of elastic connection
+            max_retries: 4
+            timeout: 30
+        }
+        upgrade_monitoring {
+            v16_migration_verification: true
+        }
+    }
+
     auth {
         # verify user tokens
         verify_user_tokens: false

diff --git a/server/config/default/hosts.conf b/server/config/default/hosts.conf
index 51aa77c..ced74b9 100644
--- a/server/config/default/hosts.conf
+++ b/server/config/default/hosts.conf
@@ -4,7 +4,7 @@ elastic {
         args {
             timeout: 60
             dead_timeout: 10
-            max_retries: 5
+            max_retries: 3
             retry_on_timeout: true
         }
         index_version: "1"
@@ -15,7 +15,7 @@ elastic {
         args {
             timeout: 60
             dead_timeout: 10
-            max_retries: 5
+            max_retries: 3
             retry_on_timeout: true
         }
         index_version: "1"

diff --git a/server/config/info.py b/server/config/info.py
index 60439ad..a970d89 100644
--- a/server/config/info.py
+++ b/server/config/info.py
@@ -44,3 +44,4 @@ def get_default_company():


 missed_es_upgrade = False
+es_connection_error = False
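Note on the two config changes above: the raised db_init_timout and the new probing section work together. With the probing defaults, an ES node that has not finished starting can keep initialization waiting for up to (max_retries - 1) * timeout seconds before the connection is declared failed (see check_elastic_empty in elastic/initialize.py below), so the old 30-second lock was presumably too tight. A quick sanity check of that budget:

    # Worst-case wait of the ES probing loop added below, under the shipped
    # defaults; the new 120s db_init_timout leaves headroom for it.
    max_retries, timeout = 4, 30
    worst_case_wait = (max_retries - 1) * timeout  # three 30s sleeps = 90s
    assert worst_case_wait < 120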
diff --git a/server/elastic/apply_mappings.py b/server/elastic/apply_mappings.py
index 3590e3c..8515b3b 100755
--- a/server/elastic/apply_mappings.py
+++ b/server/elastic/apply_mappings.py
@@ -5,56 +5,53 @@ Apply elasticsearch mappings to given hosts.
 import argparse
 import json
 from pathlib import Path
+from typing import Optional, Sequence

-import requests
-from requests.adapters import HTTPAdapter
-from requests.packages.urllib3.util.retry import Retry
+from elasticsearch import Elasticsearch

 HERE = Path(__file__).resolve().parent

-session = requests.Session()
-adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
-session.mount("http://", adapter)

+def apply_mappings_to_cluster(
+    hosts: Sequence, key: Optional[str] = None, es_args: dict = None
+):
+    """Hosts may be a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""

-def get_template(host: str, template) -> dict:
-    url = f"{host}/_template/{template}"
-    res = session.get(url)
-    return res.json()
-
-
-def apply_mappings_to_host(host: str):
-    def _send_mapping(f):
+    def _send_template(f):
         with f.open() as json_data:
             data = json.load(json_data)
-            url = f"{host}/_template/{f.stem}"
-
-            session.delete(url)
-            r = session.post(
-                url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
-            )
-            return {"mapping": f.stem, "result": r.text}
+            template_name = f.stem
+            res = es.indices.put_template(template_name, body=data)
+            return {"mapping": template_name, "result": res}

     p = HERE / "mappings"
-    return [
-        _send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
-    ]
+    if key:
+        files = (p / key).glob("*.json")
+    else:
+        files = p.glob("**/*.json")
+
+    es = Elasticsearch(hosts=hosts, **(es_args or {}))
+    return [_send_template(f) for f in files]


 def parse_args():
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    parser.add_argument("hosts", nargs="+")
+    parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
+    parser.add_argument(
+        "--hosts",
+        nargs="+",
+        help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
+    )
    return parser.parse_args()


 def main():
     args = parse_args()
-    for host in args.hosts:
-        print(">>>>> Applying mapping to " + host)
-        res = apply_mappings_to_host(host)
-        print(res)
+    print(">>>>> Applying mapping to " + str(args.hosts))
+    res = apply_mappings_to_cluster(args.hosts, args.key)
+    print(res)


 if __name__ == "__main__":
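For illustration, the reworked entry point now targets a whole cluster at once and can be limited to one mappings subdirectory; a minimal sketch of driving it from Python (the host URL and es_args values are examples, not requirements):

    from elastic.apply_mappings import apply_mappings_to_cluster

    results = apply_mappings_to_cluster(
        hosts=["http://localhost:9200"],  # nodes of a single cluster
        key="events",                     # only apply mappings/events/*.json
        es_args={"timeout": 60, "max_retries": 3, "retry_on_timeout": True},
    )
    for entry in results:
        print(entry["mapping"], entry["result"])

The equivalent command line would be: python apply_mappings.py --key events --hosts http://localhost:9200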
diff --git a/server/elastic/initialize.py b/server/elastic/initialize.py
index 5f4c63a..ea0777d 100644
--- a/server/elastic/initialize.py
+++ b/server/elastic/initialize.py
@@ -1,8 +1,10 @@
-from furl import furl
+from time import sleep

+from elasticsearch import Elasticsearch, exceptions
+
+import es_factory
 from config import config
-from elastic.apply_mappings import apply_mappings_to_host, get_template
-from es_factory import get_cluster_config
+from elastic.apply_mappings import apply_mappings_to_cluster

 log = config.logger(__file__)

@@ -15,22 +17,48 @@ class MissingElasticConfiguration(Exception):
     pass


-def _url_from_host_conf(conf: dict) -> str:
-    return furl(scheme="http", host=conf["host"], port=conf["port"]).url
+class ElasticConnectionError(Exception):
+    """
+    Raised when a connection to elastic could not be established during init
+    """
+
+    pass


-def init_es_data() -> bool:
-    """Return True if the db was empty"""
-    hosts_config = get_cluster_config("events").get("hosts")
-    if not hosts_config:
-        raise MissingElasticConfiguration("for cluster 'events'")
+def check_elastic_empty() -> bool:
+    """
+    Check whether the elasticsearch db is empty.
+    Probing settings are used instead of the default es cluster settings
+    so that connection rejects from an ES that has not fully started yet are handled correctly
+    :return: True if no "events*" template exists in the cluster
+    """
+    cluster_conf = es_factory.get_cluster_config("events")
+    max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
+    timeout = config.get("apiserver.elastic.probing.timeout", 30)
+    for retry in range(max_retries):
+        try:
+            es = Elasticsearch(hosts=cluster_conf.get("hosts"))
+            return not es.indices.get_template(name="events*")
+        except exceptions.NotFoundError as ex:
+            log.error(ex)
+            return True
+        except exceptions.ConnectionError:
+            if retry >= max_retries - 1:
+                raise ElasticConnectionError()
+            log.warn(
+                f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
+            )
+            sleep(timeout)

-    empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*")

-    for conf in hosts_config:
-        host = _url_from_host_conf(conf)
-        log.info(f"Applying mappings to host: {host}")
-        res = apply_mappings_to_host(host)
+def init_es_data():
+    for name in es_factory.get_all_cluster_names():
+        cluster_conf = es_factory.get_cluster_config(name)
+        hosts_config = cluster_conf.get("hosts")
+        if not hosts_config:
+            raise MissingElasticConfiguration(f"for cluster '{name}'")
+
+        log.info(f"Applying mappings to ES host: {hosts_config}")
+        args = cluster_conf.get("args", {})
+        res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
         log.info(res)
-
-    return empty_db

diff --git a/server/elastic/mappings/events.json b/server/elastic/mappings/events/events.json
similarity index 100%
rename from server/elastic/mappings/events.json
rename to server/elastic/mappings/events/events.json
diff --git a/server/elastic/mappings/events_log.json b/server/elastic/mappings/events/events_log.json
similarity index 100%
rename from server/elastic/mappings/events_log.json
rename to server/elastic/mappings/events/events_log.json
diff --git a/server/elastic/mappings/events_plot.json b/server/elastic/mappings/events/events_plot.json
similarity index 100%
rename from server/elastic/mappings/events_plot.json
rename to server/elastic/mappings/events/events_plot.json
diff --git a/server/elastic/mappings/events_training_debug_image.json b/server/elastic/mappings/events/events_training_debug_image.json
similarity index 100%
rename from server/elastic/mappings/events_training_debug_image.json
rename to server/elastic/mappings/events/events_training_debug_image.json
diff --git a/server/elastic/mappings/queue_metrics.json b/server/elastic/mappings/workers/queue_metrics.json
similarity index 100%
rename from server/elastic/mappings/queue_metrics.json
rename to server/elastic/mappings/workers/queue_metrics.json
diff --git a/server/elastic/mappings/worker_stats.json b/server/elastic/mappings/workers/worker_stats.json
similarity index 100%
rename from server/elastic/mappings/worker_stats.json
rename to server/elastic/mappings/workers/worker_stats.json
diff --git a/server/es_factory.py b/server/es_factory.py
index 65573b1..e039cdc 100644
--- a/server/es_factory.py
+++ b/server/es_factory.py
@@ -65,6 +65,10 @@ def connect(cluster_name):
     return _instances[cluster_name]


+def get_all_cluster_names():
+    return list(config.get("hosts.elastic"))
+
+
 def get_cluster_config(cluster_name):
     """
     Returns cluster config for the specified cluster path
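The new get_all_cluster_names() is what lets init_es_data() cover every configured cluster instead of just "events": it simply lists the top-level keys of the hosts.elastic section, which also double as the per-cluster mappings subdirectory names (events/ and workers/ above). A sketch with an illustrative config shape (host values are examples only):

    # hosts.elastic as a plain dict, mirroring the hosts.conf shown earlier
    hosts_elastic = {
        "events": {"hosts": [{"host": "127.0.0.1", "port": 9200}]},
        "workers": {"hosts": [{"host": "127.0.0.1", "port": 9200}]},
    }
    assert list(hosts_elastic) == ["events", "workers"]  # what get_all_cluster_names() returns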
diff --git a/server/mongo/initialize/__init__.py b/server/mongo/initialize/__init__.py
index 6607c33..40abe97 100644
--- a/server/mongo/initialize/__init__.py
+++ b/server/mongo/initialize/__init__.py
@@ -5,7 +5,7 @@ from config import config
 from config.info import get_default_company
 from database.model.auth import Role
 from service_repo.auth.fixed_user import FixedUser
-from .migration import _apply_migrations
+from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
 from .pre_populate import PrePopulate
 from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
 from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
@@ -26,10 +26,7 @@ def _pre_populate(company_id: str, zip_file: str):

     PrePopulate.import_from_zip(
         zip_file,
-        company_id="",
-        artifacts_path=config.get(
-            "apiserver.pre_populate.artifacts_path", None
-        ),
+        artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
     )


@@ -48,10 +45,12 @@ def pre_populate_data():
     for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
         _pre_populate(company_id=get_default_company(), zip_file=zip_file)

+    PrePopulate.update_featured_projects_order()

-def init_mongo_data() -> bool:
+
+def init_mongo_data():
     try:
-        empty_dbs = _apply_migrations(log)
+        _apply_migrations(log)

         _ensure_uuid()

@@ -86,7 +85,5 @@ def init_mongo_data():
                 ensure_fixed_user(user, log=log)
             except Exception as ex:
                 log.error(f"Failed creating fixed user {user.name}: {ex}")
-
-        return empty_dbs
     except Exception as ex:
         log.exception("Failed initializing mongodb")
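_pre_populate no longer forces company_id="" here; import_from_zip (changed below in pre_populate.py) now derives the company from the zip metadata. A condensed sketch of that resolution rule, where "default-company-id" is a placeholder for the value the real code obtains from get_default_company():

    # Sketch of the company resolution added to import_from_zip below: an
    # explicit company_id wins; otherwise the metadata "public" flag picks
    # between the public scope ("") and the default company.
    def resolve_company(company_id, metadata):
        meta_public = (metadata or {}).get("public")
        if company_id is None and meta_public is not None:
            company_id = "" if meta_public else "default-company-id"
        return company_id if company_id is not None else ""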
diff --git a/server/mongo/initialize/migration.py b/server/mongo/initialize/migration.py
index f976200..b3e3bc5 100644
--- a/server/mongo/initialize/migration.py
+++ b/server/mongo/initialize/migration.py
@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
 migration_dir = Path(__file__).resolve().parent.with_name("migrations")


-def _apply_migrations(log: Logger) -> bool:
+def check_mongo_empty() -> bool:
+    return not all(
+        get_db(alias).collection_names()
+        for alias in database.utils.get_options(Database)
+    )
+
+
+def get_last_server_version() -> Version:
+    try:
+        previous_versions = sorted(
+            (Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
+            reverse=True,
+        )
+    except ValueError as ex:
+        raise ValueError(f"Invalid database version number encountered: {ex}")
+
+    return previous_versions[0] if previous_versions else Version("0.0.0")
+
+
+def _apply_migrations(log: Logger):
     """
     Apply migrations as found in the migration dir.
-    Returns a boolean indicating whether the database was empty prior to migration.
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
     if not migration_dir.is_dir():
         raise ValueError(f"Invalid migration dir {migration_dir}")

-    empty_dbs = not any(
-        get_db(alias).collection_names()
-        for alias in database.utils.get_options(Database)
-    )
-
-    try:
-        previous_versions = sorted(
-            (Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
-            reverse=True,
-        )
-    except ValueError as ex:
-        raise ValueError(f"Invalid database version number encountered: {ex}")
-
-    last_version = previous_versions[0] if previous_versions else Version("0.0.0")
+    empty_dbs = check_mongo_empty()
+    last_version = get_last_server_version()

     try:
         new_scripts = {
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger):
         ).save()

     log.info("Finished mongodb migrations")
-
-    return empty_dbs
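get_last_server_version() relies on semantic-version ordering, using the same Version type that server.py (below) compares against 0.16.0. A short sketch of that ordering with the semantic_version package:

    from semantic_version import Version

    stored = [Version("0.14.1"), Version("0.16.0"), Version("0.15.1")]
    last = sorted(stored, reverse=True)[0] if stored else Version("0.0.0")
    assert last == Version("0.16.0")
    assert not (last < Version("0.16.0"))  # the migration gate used in server.py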
diff --git a/server/mongo/initialize/pre_populate.py b/server/mongo/initialize/pre_populate.py
index 0f93c7a..521fcc0 100644
--- a/server/mongo/initialize/pre_populate.py
+++ b/server/mongo/initialize/pre_populate.py
@@ -25,18 +25,26 @@ from typing import (
 from urllib.parse import unquote, urlparse
 from zipfile import ZipFile, ZIP_BZIP2

+import dpath
 import mongoengine
 from boltons.iterutils import chunked_iter
 from furl import furl
 from mongoengine import Q

 from bll.event import EventBLL
+from bll.task.param_utils import (
+    split_param_name,
+    hyperparams_default_section,
+    hyperparams_legacy_type,
+)
 from config import config
+from config.info import get_default_company
 from database.model import EntityVisibility
 from database.model.model import Model
 from database.model.project import Project
 from database.model.task.task import Task, ArtifactModes, TaskStatus
 from database.utils import get_options
+from tools import safe_get
 from utilities import json

 from .user import _ensure_backend_user
@@ -47,6 +55,8 @@ class PrePopulate:
     export_tag_prefix = "Exported:"
     export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
     metadata_filename = "metadata.json"
+    zip_args = dict(mode="w", compression=ZIP_BZIP2)
+    artifacts_ext = ".artifacts"

     class JsonLinesWriter:
         def __init__(self, file: BinaryIO):
@@ -192,8 +202,7 @@ class PrePopulate:
             if old_path.is_file():
                 old_path.unlink()

-        zip_args = dict(mode="w", compression=ZIP_BZIP2)
-        with ZipFile(file, **zip_args) as zfile:
+        with ZipFile(file, **cls.zip_args) as zfile:
             if metadata:
                 zfile.writestr(cls.metadata_filename, meta_str)
             artifacts = cls._export(
@@ -209,8 +218,8 @@ class PrePopulate:
             artifacts = cls._filter_artifacts(artifacts)

         if artifacts and artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = file_with_hash.with_suffix(".artifacts")
-            with ZipFile(artifacts_file, **zip_args) as zfile:
+            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
+            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
                 cls._export_artifacts(zfile, artifacts, artifacts_path)
             created_files.append(str(artifacts_file))

@@ -227,8 +236,8 @@ class PrePopulate:
     def import_from_zip(
         cls,
         filename: str,
-        company_id: str,
         artifacts_path: str,
+        company_id: Optional[str] = None,
         user_id: str = "",
         user_name: str = "",
     ):
@@ -238,6 +247,11 @@ class PrePopulate:
         try:
             with zfile.open(cls.metadata_filename) as f:
                 metadata = json.loads(f.read())
+
+                meta_public = metadata.get("public")
+                if company_id is None and meta_public is not None:
+                    company_id = "" if meta_public else get_default_company()
+
                 if not user_id:
                     meta_user_id = metadata.get("user_id", "")
                     meta_user_name = metadata.get("user_name", "")
@@ -248,17 +262,101 @@ class PrePopulate:
         if not user_id:
             user_id, user_name = "__allegroai__", "Allegro.ai"

-        user_id = _ensure_backend_user(user_id, company_id, user_name)
+        # Make sure we won't end up with an invalid company ID
+        if company_id is None:
+            company_id = ""
+
+        # Always use a public user for pre-populated data
+        user_id = _ensure_backend_user(
+            user_id=user_id, user_name=user_name, company_id="",
+        )

         cls._import(zfile, company_id, user_id, metadata)

         if artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = Path(filename).with_suffix(".artifacts")
+            artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
             if artifacts_file.is_file():
                 print(f"Unzipping artifacts into {artifacts_path}")
                 with ZipFile(artifacts_file) as zfile:
                     zfile.extractall(artifacts_path)

+    @classmethod
+    def upgrade_zip(cls, filename) -> Sequence:
+        hash_ = hashlib.md5()
+        task_file = cls._get_base_filename(Task) + ".json"
+        temp_file = Path("temp.zip")
+        file = Path(filename)
+        with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
+            for file_info in reader.filelist:
+                if file_info.orig_filename == task_file:
+                    with reader.open(file_info) as f:
+                        content = cls._upgrade_tasks(f)
+                else:
+                    content = reader.read(file_info)
+                writer.writestr(file_info, content)
+                hash_.update(content)
+
+        base_file_name, _, old_hash = file.stem.rpartition("_")
+        new_hash = hash_.hexdigest()
+        if old_hash == new_hash:
+            print(f"The file {filename} was not updated")
+            temp_file.unlink()
+            return []
+
+        new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
+        temp_file.replace(new_file)
+        updated = [str(new_file)]
+
+        artifacts_file = file.with_suffix(cls.artifacts_ext)
+        if artifacts_file.is_file():
+            new_artifacts = new_file.with_suffix(cls.artifacts_ext)
+            artifacts_file.replace(new_artifacts)
+            updated.append(str(new_artifacts))
+
+        print(f"File {str(file)} replaced with {str(new_file)}")
+        file.unlink()
+
+        return updated
+
+    @staticmethod
+    def _upgrade_task_data(task_data: dict):
+        for old_param_field, new_param_field, default_section in (
+            ("execution/parameters", "hyperparams", hyperparams_default_section),
+            ("execution/model_desc", "configuration", None),
+        ):
+            legacy = safe_get(task_data, old_param_field)
+            if not legacy:
+                continue
+            for full_name, value in legacy.items():
+                section, name = split_param_name(full_name, default_section)
+                new_path = list(filter(None, (new_param_field, section, name)))
+                if not safe_get(task_data, new_path):
+                    new_param = dict(
+                        name=name, type=hyperparams_legacy_type, value=str(value)
+                    )
+                    if section is not None:
+                        new_param["section"] = section
+                    dpath.new(task_data, new_path, new_param)
+            dpath.delete(task_data, old_param_field)
+
+    @classmethod
+    def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
+        """
+        Build a content array that contains the fixed tasks from the passed file.
+        For each task the old execution.parameters and execution.model_desc are
+        converted to the new structure.
+        The fix is done on Task objects (not the dictionary) so that the fields
+        are serialized back in the same order as they were in the original file.
+        """
+        with BytesIO() as temp:
+            with cls.JsonLinesWriter(temp) as w:
+                for line in cls.json_lines(f):
+                    task_data = Task.from_json(line).to_proper_dict()
+                    cls._upgrade_task_data(task_data)
+                    new_task = Task(**task_data)
+                    w.write(new_task.to_json())
+            return temp.getvalue()
+
     @classmethod
     def update_featured_projects_order(cls):
         featured_order = config.get("services.projects.featured_order", [])
@@ -474,6 +572,10 @@
         else:
             print(f"Artifact {full_path} not found")

+    @staticmethod
+    def _get_base_filename(cls_: type):
+        return f"{cls_.__module__}.{cls_.__name__}"
+
     @classmethod
     def _export(
         cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
@@ -488,7 +590,7 @@ class PrePopulate:
             items = sorted(entities[cls_], key=attrgetter("id"))
             if not items:
                 continue
-            base_filename = f"{cls_.__module__}.{cls_.__name__}"
+            base_filename = cls._get_base_filename(cls_)
             for item in items:
                 artifacts.extend(
                     cls._export_entity_related_data(
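To make the hyper-parameter migration concrete, here is a self-contained sketch of the transformation _upgrade_task_data applies to a legacy task dict. The real code goes through dpath, safe_get and split_param_name; the "/" separator, the "Args" default section and the "legacy" type string below are illustrative assumptions standing in for hyperparams_default_section and hyperparams_legacy_type:

    legacy_task = {"execution": {"parameters": {"epochs": 10, "lr/initial": 0.01}}}

    hyperparams = {}
    for full_name, value in legacy_task["execution"].pop("parameters").items():
        section, _, name = full_name.rpartition("/")
        section = section or "Args"  # assumed default section
        hyperparams.setdefault(section, {})[name] = {
            "section": section, "name": name, "type": "legacy", "value": str(value),
        }

    # hyperparams now holds {"Args": {"epochs": ...}, "lr": {"initial": ...}},
    # and execution.parameters is gone, mirroring the dpath.new/dpath.delete pair.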
diff --git a/server/schema/services/auth.conf b/server/schema/services/auth.conf
index 3755dfb..79d89d0 100644
--- a/server/schema/services/auth.conf
+++ b/server/schema/services/auth.conf
@@ -328,8 +328,10 @@ fixed_users_mode {
                     description: "Fixed users mode enabled"
                     type: boolean
                 }
-                migration_warning {
-                    type: boolean
+                server_errors {
+                    description: "Server initialization errors"
+                    type: object
+                    additionalProperties: True
                 }
             }
         }
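With the schema change above, auth.fixed_users_mode replies with a free-form server_errors object in place of the old migration_warning boolean. A hypothetical reply when the v0.16 ES migration was detected as missing (see services/auth.py below for how the object is built, with only the errors that actually occurred included):

    reply = {
        "enabled": True,
        "guest": {"enabled": False},
        "server_errors": {"missed_es_upgrade": True},
    }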
diff --git a/server/server.py b/server/server.py
index 4458364..6fdc581 100644
--- a/server/server.py
+++ b/server/server.py
@@ -5,14 +5,20 @@ from hashlib import md5
 from flask import Flask, request, Response
 from flask_compress import Compress
 from flask_cors import CORS
+from semantic_version import Version
 from werkzeug.exceptions import BadRequest

 import database
 from apierrors.base import BaseError
 from bll.statistics.stats_reporter import StatisticsReporter
 from config import config, info
-from elastic.initialize import init_es_data
-from mongo.initialize import init_mongo_data, pre_populate_data
+from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
+from mongo.initialize import (
+    init_mongo_data,
+    pre_populate_data,
+    check_mongo_empty,
+    get_last_server_version,
+)
 from service_repo import ServiceRepo, APICall
 from service_repo.auth import AuthType
 from service_repo.errors import PathParsingError
@@ -38,13 +44,37 @@
 database.initialize()

 # build a key that uniquely identifies specific mongo instance
 hosts_string = ";".join(sorted(database.get_hosts()))
 key = "db_init_" + md5(hosts_string.encode()).hexdigest()
-with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
-    empty_es = init_es_data()
-    empty_db = init_mongo_data()
-if empty_es and not empty_db:
-    log.info(f"ES database seems not migrated")
-    info.missed_es_upgrade = True
-if empty_db and config.get("apiserver.pre_populate.enabled", False):
+with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
+    upgrade_monitoring = config.get(
+        "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
+    )
+    try:
+        empty_es = check_elastic_empty()
+    except ElasticConnectionError as err:
+        if not upgrade_monitoring:
+            raise
+        log.error(err)
+        info.es_connection_error = True
+
+    empty_db = check_mongo_empty()
+    if upgrade_monitoring:
+        if not empty_db and (info.es_connection_error or empty_es):
+            if get_last_server_version() < Version("0.16.0"):
+                log.info("ES database seems not migrated")
+                info.missed_es_upgrade = True
+        proceed_with_init = not (info.es_connection_error or info.missed_es_upgrade)
+    else:
+        proceed_with_init = True
+
+    if proceed_with_init:
+        init_es_data()
+        init_mongo_data()
+
+if (
+    proceed_with_init
+    and empty_db
+    and config.get("apiserver.pre_populate.enabled", False)
+):
     pre_populate_data()

diff --git a/server/services/auth.py b/server/services/auth.py
index 8e0eb26..7229c02 100644
--- a/server/services/auth.py
+++ b/server/services/auth.py
@@ -176,12 +176,19 @@ def update(call, company_id, _):

 @endpoint("auth.fixed_users_mode")
 def fixed_users_mode(call: APICall, *_, **__):
+    server_errors = {
+        name: error
+        for name, error in zip(
+            ("missed_es_upgrade", "es_connection_error"),
+            (info.missed_es_upgrade, info.es_connection_error),
+        )
+        if error
+    }
+
     data = {
         "enabled": FixedUser.enabled(),
-        "migration_warning": info.missed_es_upgrade,
-        "guest": {
-            "enabled": FixedUser.guest_enabled(),
-        }
+        "guest": {"enabled": FixedUser.guest_enabled()},
+        "server_errors": server_errors,
     }
     guest_user = FixedUser.get_guest_user()
     if guest_user:
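Taken together, the server.py changes gate both initialization and pre-population behind a single decision. A condensed, self-contained sketch of that decision (flag names mirror config.info; the should_init function name is ours, not the server's):

    from semantic_version import Version

    def should_init(empty_db, empty_es, es_connection_error, last_version):
        # Mirrors the upgrade_monitoring branch in server.py above
        missed_es_upgrade = (
            not empty_db
            and (es_connection_error or empty_es)
            and last_version < Version("0.16.0")
        )
        return not (es_connection_error or missed_es_upgrade)

    # fresh install: nothing to migrate, proceed with init
    assert should_init(True, True, False, Version("0.0.0"))
    # pre-0.16 mongo data with an empty ES: block init and surface the error
    assert not should_init(False, True, False, Version("0.15.1"))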