Enhance ES7 initialization and migration support

Support older task hyper-parameter migration on pre-population
allegroai 2020-08-10 08:53:41 +03:00
parent cd4ce30f7c
commit 7816b402bb
18 changed files with 282 additions and 98 deletions

View File

@ -35,7 +35,7 @@
# time in seconds to take an exclusive lock to init es and mongodb
# not including the pre_populate
db_init_timout: 30
db_init_timout: 120
mongo {
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
@ -47,6 +47,17 @@
}
}
elastic {
probing {
# settings for initial probing of the elastic connection
max_retries: 4
timeout: 30
}
upgrade_monitoring {
v16_migration_verification: true
}
}
auth {
# verify user tokens
verify_user_tokens: false

View File

@ -4,7 +4,7 @@ elastic {
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"
@ -15,7 +15,7 @@ elastic {
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"

View File

@ -44,3 +44,4 @@ def get_default_company():
missed_es_upgrade = False
es_connection_error = False

View File

@ -5,56 +5,53 @@ Apply elasticsearch mappings to given hosts.
import argparse
import json
from pathlib import Path
from typing import Optional, Sequence
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from elasticsearch import Elasticsearch
HERE = Path(__file__).resolve().parent
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount("http://", adapter)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def get_template(host: str, template) -> dict:
url = f"{host}/_template/{template}"
res = session.get(url)
return res.json()
def apply_mappings_to_host(host: str):
def _send_mapping(f):
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
url = f"{host}/_template/{f.stem}"
session.delete(url)
r = session.post(
url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
)
return {"mapping": f.stem, "result": r.text}
template_name = f.stem
res = es.indices.put_template(template_name, body=data)
return {"mapping": template_name, "result": res}
p = HERE / "mappings"
return [
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
]
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
return [_send_template(f) for f in files]
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("hosts", nargs="+")
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
parser.add_argument(
"--hosts",
nargs="+",
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
)
return parser.parse_args()
def main():
args = parse_args()
for host in args.hosts:
print(">>>>> Applying mapping to " + host)
res = apply_mappings_to_host(host)
print(res)
print(">>>>> Applying mapping to " + str(args.hosts))
res = apply_mappings_to_cluster(args.hosts, args.key)
print(res)
if __name__ == "__main__":
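A minimal usage sketch of the reworked apply function; the host URL and es_args values below are illustrative assumptions, not part of this commit:

from elastic.apply_mappings import apply_mappings_to_cluster

# Apply only the "events" templates to a hypothetical single-node cluster
results = apply_mappings_to_cluster(
    hosts=["http://localhost:9200"],
    key="events",
    es_args={"timeout": 60, "max_retries": 3, "retry_on_timeout": True},
)
for entry in results:
    print(entry["mapping"], entry["result"])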

View File

@ -1,8 +1,10 @@
from furl import furl
from time import sleep
from elasticsearch import Elasticsearch, exceptions
import es_factory
from config import config
from elastic.apply_mappings import apply_mappings_to_host, get_template
from es_factory import get_cluster_config
from elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
@ -15,22 +17,48 @@ class MissingElasticConfiguration(Exception):
pass
def _url_from_host_conf(conf: dict) -> str:
return furl(scheme="http", host=conf["host"], port=conf["port"]).url
class ElasticConnectionError(Exception):
"""
Raised when a connection to elastic could not be established during init
"""
pass
def init_es_data() -> bool:
"""Return True if the db was empty"""
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
def check_elastic_empty() -> bool:
"""
Check the elasticsearch connection and whether the events templates are present
Use the probing settings and not the default es cluster ones
so that connection rejects due to ES not being fully started yet are handled correctly
:return: True if the elastic DB is empty (no events templates found)
"""
cluster_conf = es_factory.get_cluster_config("events")
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
timeout = config.get("apiserver.elastic.probing.timeout", 30)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
except exceptions.ConnectionError:
if retry >= max_retries - 1:
raise ElasticConnectionError()
log.warn(
f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
)
sleep(timeout)
empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*")
for conf in hosts_config:
host = _url_from_host_conf(conf)
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
def init_es_data():
for name in es_factory.get_all_cluster_names():
cluster_conf = es_factory.get_cluster_config(name)
hosts_config = cluster_conf.get("hosts")
if not hosts_config:
raise MissingElasticConfiguration(f"for cluster '{name}'")
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
log.info(res)
return empty_db

View File

@ -65,6 +65,10 @@ def connect(cluster_name):
return _instances[cluster_name]
def get_all_cluster_names():
return list(config.get("hosts.elastic"))
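# Sketch of the config shape this relies on (cluster names other than "events" are
# assumptions, not confirmed by this diff): hosts.elastic is a mapping keyed by cluster name, e.g.
#   elastic {
#     events { hosts: [{host: "127.0.0.1", port: 9200}], args { timeout: 60 } }
#   }
# so get_all_cluster_names() returns the configured names, e.g. ["events", ...].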
def get_cluster_config(cluster_name):
"""
Returns cluster config for the specified cluster path

View File

@ -5,7 +5,7 @@ from config import config
from config.info import get_default_company
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
@ -26,10 +26,7 @@ def _pre_populate(company_id: str, zip_file: str):
PrePopulate.import_from_zip(
zip_file,
company_id="",
artifacts_path=config.get(
"apiserver.pre_populate.artifacts_path", None
),
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
)
@ -48,10 +45,12 @@ def pre_populate_data():
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
PrePopulate.update_featured_projects_order()
def init_mongo_data() -> bool:
def init_mongo_data():
try:
empty_dbs = _apply_migrations(log)
_apply_migrations(log)
_ensure_uuid()
@ -86,7 +85,5 @@ def init_mongo_data() -> bool:
ensure_fixed_user(user, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
return empty_dbs
except Exception as ex:
log.exception("Failed initializing mongodb")

View File

@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def _apply_migrations(log: Logger) -> bool:
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
def get_last_server_version() -> Version:
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
return previous_versions[0] if previous_versions else Version("0.0.0")
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = not any(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
try:
new_scripts = {
@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
).save()
log.info("Finished mongodb migrations")
return empty_dbs
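A short sketch of get_last_server_version()'s fallback semantics, using hypothetical version records:

from semantic_version import Version

# With DatabaseVersion records "0.14.1" and "0.15.0" stored, the helper returns Version("0.15.0");
# with no records at all it falls back to Version("0.0.0"), which satisfies the
# < Version("0.16.0") migration-verification check added to server.py below.
assert max(Version("0.14.1"), Version("0.15.0")) == Version("0.15.0")
assert Version("0.0.0") < Version("0.16.0")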

View File

@ -25,18 +25,26 @@ from typing import (
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q
from bll.event import EventBLL
from bll.task.param_utils import (
split_param_name,
hyperparams_default_section,
hyperparams_legacy_type,
)
from config import config
from config.info import get_default_company
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from tools import safe_get
from utilities import json
from .user import _ensure_backend_user
@ -47,6 +55,8 @@ class PrePopulate:
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@ -192,8 +202,7 @@ class PrePopulate:
if old_path.is_file():
old_path.unlink()
zip_args = dict(mode="w", compression=ZIP_BZIP2)
with ZipFile(file, **zip_args) as zfile:
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
@ -209,8 +218,8 @@ class PrePopulate:
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(".artifacts")
with ZipFile(artifacts_file, **zip_args) as zfile:
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))
@ -227,8 +236,8 @@ class PrePopulate:
def import_from_zip(
cls,
filename: str,
company_id: str,
artifacts_path: str,
company_id: Optional[str] = None,
user_id: str = "",
user_name: str = "",
):
@ -238,6 +247,11 @@ class PrePopulate:
try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
meta_public = metadata.get("public")
if company_id is None and meta_public is not None:
company_id = "" if meta_public else get_default_company()
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
@ -248,17 +262,101 @@ class PrePopulate:
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
user_id = _ensure_backend_user(user_id, company_id, user_name)
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
user_id = _ensure_backend_user(
user_id=user_id, user_name=user_name, company_id="",
)
cls._import(zfile, company_id, user_id, metadata)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(".artifacts")
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
print(f"Unzipping artifacts into {artifacts_path}")
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)
@classmethod
def upgrade_zip(cls, filename) -> Sequence:
hash_ = hashlib.md5()
task_file = cls._get_base_filename(Task) + ".json"
temp_file = Path("temp.zip")
file = Path(filename)
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
for file_info in reader.filelist:
if file_info.orig_filename == task_file:
with reader.open(file_info) as f:
content = cls._upgrade_tasks(f)
else:
content = reader.read(file_info)
writer.writestr(file_info, content)
hash_.update(content)
base_file_name, _, old_hash = file.stem.rpartition("_")
new_hash = hash_.hexdigest()
if old_hash == new_hash:
print(f"The file {filename} was not updated")
temp_file.unlink()
return []
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
temp_file.replace(new_file)
updated = [str(new_file)]
artifacts_file = file.with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
artifacts_file.replace(new_artifacts)
updated.append(str(new_artifacts))
print(f"File {str(file)} replaced with {str(new_file)}")
file.unlink()
return updated
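# Hypothetical usage sketch for the new upgrade path (the file name is illustrative):
#   updated_files = PrePopulate.upgrade_zip("/opt/trains/pre_populate/export_0123abcd.zip")
# The call rewrites the task entries, renames the zip (and any matching ".artifacts"
# file) to carry the new content hash, and returns the new file names, or an empty
# list when the content hash is unchanged.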
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
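# Illustrative sketch (parameter names and values are hypothetical): a legacy entry such as
#   {"execution": {"parameters": {"lr": 0.01}}}
# is moved under the new field as
#   {"hyperparams": {<hyperparams_default_section>: {"lr": {"name": "lr",
#       "type": <hyperparams_legacy_type>, "value": "0.01", "section": <hyperparams_default_section>}}}}
# while "execution/model_desc" entries land under "configuration" without a section,
# and the old paths are deleted once migrated.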
@classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
"""
Build a content array containing the fixed tasks from the passed file
For each task, the old execution.parameters and execution.model_desc are
converted to the new structure.
The fix is applied through Task objects (not raw dictionaries) so that
the fields are serialized back in the same order as in the original file
"""
with BytesIO() as temp:
with cls.JsonLinesWriter(temp) as w:
for line in cls.json_lines(f):
task_data = Task.from_json(line).to_proper_dict()
cls._upgrade_task_data(task_data)
new_task = Task(**task_data)
w.write(new_task.to_json())
return temp.getvalue()
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])
@ -474,6 +572,10 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@staticmethod
def _get_base_filename(cls_: type):
return f"{cls_.__module__}.{cls_.__name__}"
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
@ -488,7 +590,7 @@ class PrePopulate:
items = sorted(entities[cls_], key=attrgetter("id"))
if not items:
continue
base_filename = f"{cls_.__module__}.{cls_.__name__}"
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(

View File

@ -328,8 +328,10 @@ fixed_users_mode {
description: "Fixed users mode enabled"
type: boolean
}
migration_warning {
type: boolean
server_errors {
description: "Server initialization errors"
type: object
additionalProperties: True
}
}
}

View File

@ -5,14 +5,20 @@ from hashlib import md5
from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config, info
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data, pre_populate_data
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
init_mongo_data,
pre_populate_data,
check_mongo_empty,
get_last_server_version,
)
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
@ -38,13 +44,37 @@ database.initialize()
# build a key that uniquely identifies specific mongo instance
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
empty_es = init_es_data()
empty_db = init_mongo_data()
if empty_es and not empty_db:
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
if empty_db and config.get("apiserver.pre_populate.enabled", False):
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
upgrade_monitoring = config.get(
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
)
try:
empty_es = check_elastic_empty()
except ElasticConnectionError as err:
if not upgrade_monitoring:
raise
log.error(err)
info.es_connection_error = True
empty_db = check_mongo_empty()
if upgrade_monitoring:
if not empty_db and (info.es_connection_error or empty_es):
if get_last_server_version() < Version("0.16.0"):
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
proceed_with_init = not (info.es_connection_error or info.missed_es_upgrade)
else:
proceed_with_init = True
if proceed_with_init:
init_es_data()
init_mongo_data()
if (
proceed_with_init
and empty_db
and config.get("apiserver.pre_populate.enabled", False)
):
pre_populate_data()

View File

@ -176,12 +176,19 @@ def update(call, company_id, _):
@endpoint("auth.fixed_users_mode")
def fixed_users_mode(call: APICall, *_, **__):
server_errors = {
name: error
for name, error in zip(
("missed_es_upgrade", "es_connection_error"),
(info.missed_es_upgrade, info.es_connection_error),
)
if error
}
data = {
"enabled": FixedUser.enabled(),
"migration_warning": info.missed_es_upgrade,
"guest": {
"enabled": FixedUser.guest_enabled(),
}
"guest": {"enabled": FixedUser.guest_enabled()},
"server_errors": server_errors,
}
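# Sketch of the resulting payload when the ES migration check failed (values are illustrative):
#   {"enabled": true, "guest": {"enabled": false}, "server_errors": {"missed_es_upgrade": true}}
# server_errors only carries the flags that are actually set, so it is empty on a healthy start.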
guest_user = FixedUser.get_guest_user()
if guest_user: