Enhance ES7 initialization and migration support

Support older task hyper-parameter migration on pre-population
allegroai 2020-08-10 08:53:41 +03:00
parent cd4ce30f7c
commit 7816b402bb
18 changed files with 282 additions and 98 deletions

View File

@@ -35,7 +35,7 @@
 # time in seconds to take an exclusive lock to init es and mongodb
 # not including the pre_populate
-db_init_timout: 30
+db_init_timout: 120
 
 mongo {
     # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
@@ -47,6 +47,17 @@
     }
 }
 
+elastic {
+    probing {
+        # settings for inital probing of elastic connection
+        max_retries: 4
+        timeout: 30
+    }
+    upgrade_monitoring {
+        v16_migration_verification: true
+    }
+}
+
 auth {
     # verify user tokens
     verify_user_tokens: false

View File

@@ -4,7 +4,7 @@ elastic {
         args {
             timeout: 60
             dead_timeout: 10
-            max_retries: 5
+            max_retries: 3
             retry_on_timeout: true
         }
         index_version: "1"
@@ -15,7 +15,7 @@ elastic {
         args {
             timeout: 60
             dead_timeout: 10
-            max_retries: 5
+            max_retries: 3
             retry_on_timeout: true
         }
         index_version: "1"

View File

@@ -44,3 +44,4 @@ def get_default_company():
 
 
 missed_es_upgrade = False
+es_connection_error = False

View File

@@ -5,56 +5,53 @@ Apply elasticsearch mappings to given hosts.
 import argparse
 import json
 from pathlib import Path
+from typing import Optional, Sequence
 
-import requests
-from requests.adapters import HTTPAdapter
-from requests.packages.urllib3.util.retry import Retry
+from elasticsearch import Elasticsearch
 
 HERE = Path(__file__).resolve().parent
 
-session = requests.Session()
-adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
-session.mount("http://", adapter)
 
+def apply_mappings_to_cluster(
+    hosts: Sequence, key: Optional[str] = None, es_args: dict = None
+):
+    """Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
 
-def get_template(host: str, template) -> dict:
-    url = f"{host}/_template/{template}"
-    res = session.get(url)
-    return res.json()
-
-
-def apply_mappings_to_host(host: str):
-    def _send_mapping(f):
+    def _send_template(f):
         with f.open() as json_data:
             data = json.load(json_data)
-            url = f"{host}/_template/{f.stem}"
-            session.delete(url)
-            r = session.post(
-                url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
-            )
-            return {"mapping": f.stem, "result": r.text}
+            template_name = f.stem
+            res = es.indices.put_template(template_name, body=data)
+            return {"mapping": template_name, "result": res}
 
     p = HERE / "mappings"
-    return [
-        _send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
-    ]
+    if key:
+        files = (p / key).glob("*.json")
+    else:
+        files = p.glob("**/*.json")
+
+    es = Elasticsearch(hosts=hosts, **(es_args or {}))
+    return [_send_template(f) for f in files]
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    parser.add_argument("hosts", nargs="+")
+    parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
+    parser.add_argument(
+        "--hosts",
+        nargs="+",
+        help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
-    for host in args.hosts:
-        print(">>>>> Applying mapping to " + host)
-        res = apply_mappings_to_host(host)
-        print(res)
+    print(">>>>> Applying mapping to " + str(args.hosts))
+    res = apply_mappings_to_cluster(args.hosts, args.key)
+    print(res)
 
 
 if __name__ == "__main__":
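
The refactored script above replaces the per-host requests logic with a single Elasticsearch client call per cluster. A minimal usage sketch of the new entry point (the host dict and es_args values below are placeholders based on defaults seen elsewhere in this commit, not part of this file):

    from elastic.apply_mappings import apply_mappings_to_cluster

    # Apply only the "events" templates to a single-host cluster (illustrative host).
    results = apply_mappings_to_cluster(
        hosts=[{"host": "127.0.0.1", "port": 9200}],
        key="events",
        es_args={"timeout": 60, "dead_timeout": 10, "max_retries": 3, "retry_on_timeout": True},
    )
    for item in results:
        print(item["mapping"], item["result"])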

View File

@@ -1,8 +1,10 @@
-from furl import furl
+from time import sleep
 
+from elasticsearch import Elasticsearch, exceptions
+
+import es_factory
 from config import config
-from elastic.apply_mappings import apply_mappings_to_host, get_template
-from es_factory import get_cluster_config
+from elastic.apply_mappings import apply_mappings_to_cluster
 
 log = config.logger(__file__)
 
@@ -15,22 +17,48 @@ class MissingElasticConfiguration(Exception):
     pass
 
 
-def _url_from_host_conf(conf: dict) -> str:
-    return furl(scheme="http", host=conf["host"], port=conf["port"]).url
+class ElasticConnectionError(Exception):
+    """
+    Exception when could not connect to elastic during init
+    """
+    pass
 
 
-def init_es_data() -> bool:
-    """Return True if the db was empty"""
-    hosts_config = get_cluster_config("events").get("hosts")
-    if not hosts_config:
-        raise MissingElasticConfiguration("for cluster 'events'")
+def check_elastic_empty() -> bool:
+    """
+    Check for elasticsearch connection
+    Use probing settings and not the default es cluster ones
+    so that we can handle correctly the connection rejects due to ES not fully started yet
+    :return:
+    """
+    cluster_conf = es_factory.get_cluster_config("events")
+    max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
+    timeout = config.get("apiserver.elastic.probing.timeout", 30)
+    for retry in range(max_retries):
+        try:
+            es = Elasticsearch(hosts=cluster_conf.get("hosts"))
+            return not es.indices.get_template(name="events*")
+        except exceptions.NotFoundError as ex:
+            log.error(ex)
+            return True
+        except exceptions.ConnectionError:
+            if retry >= max_retries - 1:
+                raise ElasticConnectionError()
+            log.warn(
+                f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
+            )
+            sleep(timeout)
 
-    empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*")
 
-    for conf in hosts_config:
-        host = _url_from_host_conf(conf)
-        log.info(f"Applying mappings to host: {host}")
-        res = apply_mappings_to_host(host)
+def init_es_data():
+    for name in es_factory.get_all_cluster_names():
+        cluster_conf = es_factory.get_cluster_config(name)
+        hosts_config = cluster_conf.get("hosts")
+        if not hosts_config:
+            raise MissingElasticConfiguration(f"for cluster '{name}'")
+
+        log.info(f"Applying mappings to ES host: {hosts_config}")
+        args = cluster_conf.get("args", {})
+        res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
         log.info(res)
-
-    return empty_db
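
Note on the probing defaults above: with max_retries: 4 and timeout: 30, check_elastic_empty() makes up to four connection attempts and sleeps 30 seconds after each failed attempt except the last, so Elasticsearch is given roughly 90 seconds to come up before ElasticConnectionError is raised. The db_init_timout increase from 30 to 120 seconds in apiserver.conf keeps the distributed init lock alive for that whole probing window.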

View File

@@ -65,6 +65,10 @@ def connect(cluster_name):
     return _instances[cluster_name]
 
 
+def get_all_cluster_names():
+    return list(config.get("hosts.elastic"))
+
+
 def get_cluster_config(cluster_name):
     """
     Returns cluster config for the specified cluster path
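
The new get_all_cluster_names() helper lists the keys under hosts.elastic, which is what lets init_es_data() iterate every configured cluster instead of only "events". A hedged illustration (the actual cluster names in hosts.conf are not visible in this diff; "events" and "datasets" follow the --key help text above):

    import es_factory

    for name in es_factory.get_all_cluster_names():  # e.g. ["events", "datasets"]
        print(name, es_factory.get_cluster_config(name).get("hosts"))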

View File

@@ -5,7 +5,7 @@ from config import config
 from config.info import get_default_company
 from database.model.auth import Role
 from service_repo.auth.fixed_user import FixedUser
-from .migration import _apply_migrations
+from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
 from .pre_populate import PrePopulate
 from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
 from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
@@ -26,10 +26,7 @@ def _pre_populate(company_id: str, zip_file: str):
     PrePopulate.import_from_zip(
         zip_file,
-        company_id="",
-        artifacts_path=config.get(
-            "apiserver.pre_populate.artifacts_path", None
-        ),
+        artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
     )
@@ -48,10 +45,12 @@ def pre_populate_data():
     for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
         _pre_populate(company_id=get_default_company(), zip_file=zip_file)
 
+    PrePopulate.update_featured_projects_order()
+
 
-def init_mongo_data() -> bool:
+def init_mongo_data():
     try:
-        empty_dbs = _apply_migrations(log)
+        _apply_migrations(log)
         _ensure_uuid()
@@ -86,7 +85,5 @@ def init_mongo_data() -> bool:
                 ensure_fixed_user(user, log=log)
             except Exception as ex:
                 log.error(f"Failed creating fixed user {user.name}: {ex}")
-
-        return empty_dbs
     except Exception as ex:
         log.exception("Failed initializing mongodb")

View File

@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
 migration_dir = Path(__file__).resolve().parent.with_name("migrations")
 
 
-def _apply_migrations(log: Logger) -> bool:
+def check_mongo_empty() -> bool:
+    return not all(
+        get_db(alias).collection_names()
+        for alias in database.utils.get_options(Database)
+    )
+
+
+def get_last_server_version() -> Version:
+    try:
+        previous_versions = sorted(
+            (Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
+            reverse=True,
+        )
+    except ValueError as ex:
+        raise ValueError(f"Invalid database version number encountered: {ex}")
+
+    return previous_versions[0] if previous_versions else Version("0.0.0")
+
+
+def _apply_migrations(log: Logger):
     """
     Apply migrations as found in the migration dir.
     Returns a boolean indicating whether the database was empty prior to migration.
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
     if not migration_dir.is_dir():
         raise ValueError(f"Invalid migration dir {migration_dir}")
 
-    empty_dbs = not any(
-        get_db(alias).collection_names()
-        for alias in database.utils.get_options(Database)
-    )
-
-    try:
-        previous_versions = sorted(
-            (Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
-            reverse=True,
-        )
-    except ValueError as ex:
-        raise ValueError(f"Invalid database version number encountered: {ex}")
-
-    last_version = previous_versions[0] if previous_versions else Version("0.0.0")
+    empty_dbs = check_mongo_empty()
+    last_version = get_last_server_version()
 
     try:
         new_scripts = {
@@ -82,5 +89,3 @@
             ).save()
 
     log.info("Finished mongodb migrations")
-
-    return empty_dbs
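
A small sketch of the version semantics that get_last_server_version() relies on (semantic_version is the library imported by server.py later in this commit; the version strings are illustrative):

    from semantic_version import Version

    # DatabaseVersion documents store strings such as "0.15.1"; sorting Version
    # objects in reverse order picks the highest as the last server version.
    versions = sorted([Version("0.14.1"), Version("0.15.1")], reverse=True)
    assert versions[0] == Version("0.15.1")

    # server.py compares this value against 0.16.0 to decide whether the ES 7
    # migration may have been skipped.
    print(Version("0.15.1") < Version("0.16.0"))  # True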

View File

@@ -25,18 +25,26 @@ from typing import (
 from urllib.parse import unquote, urlparse
 from zipfile import ZipFile, ZIP_BZIP2
 
+import dpath
 import mongoengine
 from boltons.iterutils import chunked_iter
 from furl import furl
 from mongoengine import Q
 
 from bll.event import EventBLL
+from bll.task.param_utils import (
+    split_param_name,
+    hyperparams_default_section,
+    hyperparams_legacy_type,
+)
 from config import config
+from config.info import get_default_company
 from database.model import EntityVisibility
 from database.model.model import Model
 from database.model.project import Project
 from database.model.task.task import Task, ArtifactModes, TaskStatus
 from database.utils import get_options
+from tools import safe_get
 from utilities import json
 
 from .user import _ensure_backend_user
@@ -47,6 +55,8 @@ class PrePopulate:
     export_tag_prefix = "Exported:"
     export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
     metadata_filename = "metadata.json"
+    zip_args = dict(mode="w", compression=ZIP_BZIP2)
+    artifacts_ext = ".artifacts"
 
     class JsonLinesWriter:
         def __init__(self, file: BinaryIO):
@@ -192,8 +202,7 @@
         if old_path.is_file():
             old_path.unlink()
 
-        zip_args = dict(mode="w", compression=ZIP_BZIP2)
-        with ZipFile(file, **zip_args) as zfile:
+        with ZipFile(file, **cls.zip_args) as zfile:
             if metadata:
                 zfile.writestr(cls.metadata_filename, meta_str)
             artifacts = cls._export(
@@ -209,8 +218,8 @@
         artifacts = cls._filter_artifacts(artifacts)
 
         if artifacts and artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = file_with_hash.with_suffix(".artifacts")
-            with ZipFile(artifacts_file, **zip_args) as zfile:
+            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
+            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
                 cls._export_artifacts(zfile, artifacts, artifacts_path)
             created_files.append(str(artifacts_file))
@@ -227,8 +236,8 @@
     def import_from_zip(
         cls,
         filename: str,
-        company_id: str,
         artifacts_path: str,
+        company_id: Optional[str] = None,
         user_id: str = "",
         user_name: str = "",
     ):
@@ -238,6 +247,11 @@
         try:
             with zfile.open(cls.metadata_filename) as f:
                 metadata = json.loads(f.read())
+
+            meta_public = metadata.get("public")
+            if company_id is None and meta_public is not None:
+                company_id = "" if meta_public else get_default_company()
+
             if not user_id:
                 meta_user_id = metadata.get("user_id", "")
                 meta_user_name = metadata.get("user_name", "")
@@ -248,17 +262,101 @@
         if not user_id:
             user_id, user_name = "__allegroai__", "Allegro.ai"
 
-        user_id = _ensure_backend_user(user_id, company_id, user_name)
+        # Make sure we won't end up with an invalid company ID
+        if company_id is None:
+            company_id = ""
+
+        # Always use a public user for pre-populated data
+        user_id = _ensure_backend_user(
+            user_id=user_id, user_name=user_name, company_id="",
+        )
 
         cls._import(zfile, company_id, user_id, metadata)
 
         if artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = Path(filename).with_suffix(".artifacts")
+            artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
             if artifacts_file.is_file():
                 print(f"Unzipping artifacts into {artifacts_path}")
                 with ZipFile(artifacts_file) as zfile:
                     zfile.extractall(artifacts_path)
 
+    @classmethod
+    def upgrade_zip(cls, filename) -> Sequence:
+        hash_ = hashlib.md5()
+        task_file = cls._get_base_filename(Task) + ".json"
+        temp_file = Path("temp.zip")
+        file = Path(filename)
+        with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
+            for file_info in reader.filelist:
+                if file_info.orig_filename == task_file:
+                    with reader.open(file_info) as f:
+                        content = cls._upgrade_tasks(f)
+                else:
+                    content = reader.read(file_info)
+                writer.writestr(file_info, content)
+                hash_.update(content)
+
+        base_file_name, _, old_hash = file.stem.rpartition("_")
+        new_hash = hash_.hexdigest()
+        if old_hash == new_hash:
+            print(f"The file {filename} was not updated")
+            temp_file.unlink()
+            return []
+
+        new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
+        temp_file.replace(new_file)
+        upadated = [str(new_file)]
+
+        artifacts_file = file.with_suffix(cls.artifacts_ext)
+        if artifacts_file.is_file():
+            new_artifacts = new_file.with_suffix(cls.artifacts_ext)
+            artifacts_file.replace(new_artifacts)
+            upadated.append(str(new_artifacts))
+
+        print(f"File {str(file)} replaced with {str(new_file)}")
+        file.unlink()
+        return upadated
+
+    @staticmethod
+    def _upgrade_task_data(task_data: dict):
+        for old_param_field, new_param_field, default_section in (
+            ("execution/parameters", "hyperparams", hyperparams_default_section),
+            ("execution/model_desc", "configuration", None),
+        ):
+            legacy = safe_get(task_data, old_param_field)
+            if not legacy:
+                continue
+            for full_name, value in legacy.items():
+                section, name = split_param_name(full_name, default_section)
+                new_path = list(filter(None, (new_param_field, section, name)))
+                if not safe_get(task_data, new_path):
+                    new_param = dict(
+                        name=name, type=hyperparams_legacy_type, value=str(value)
+                    )
+                    if section is not None:
+                        new_param["section"] = section
+                    dpath.new(task_data, new_path, new_param)
+            dpath.delete(task_data, old_param_field)
+
+    @classmethod
+    def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
+        """
+        Build content array that contains fixed tasks from the passed file
+        For each task the old execution.parameters and model.design are
+        converted to the new structure.
+        The fix is done on Task objects (not the dictionary) so that
+        the fields are serialized back in the same order as they were in the original file
+        """
+        with BytesIO() as temp:
+            with cls.JsonLinesWriter(temp) as w:
+                for line in cls.json_lines(f):
+                    task_data = Task.from_json(line).to_proper_dict()
+                    cls._upgrade_task_data(task_data)
+                    new_task = Task(**task_data)
+                    w.write(new_task.to_json())
+            return temp.getvalue()
+
     @classmethod
     def update_featured_projects_order(cls):
         featured_order = config.get("services.projects.featured_order", [])
@@ -474,6 +572,10 @@
             else:
                 print(f"Artifact {full_path} not found")
 
+    @staticmethod
+    def _get_base_filename(cls_: type):
+        return f"{cls_.__module__}.{cls_.__name__}"
+
     @classmethod
     def _export(
         cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
@@ -488,7 +590,7 @@
             items = sorted(entities[cls_], key=attrgetter("id"))
             if not items:
                 continue
-            base_filename = f"{cls_.__module__}.{cls_.__name__}"
+            base_filename = cls._get_base_filename(cls_)
             for item in items:
                 artifacts.extend(
                     cls._export_entity_related_data(
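
To make the _upgrade_task_data() conversion above concrete, a hedged before/after sketch. The actual values of hyperparams_default_section and hyperparams_legacy_type are defined in bll.task.param_utils and are not shown in this commit, and the exact splitting behaviour of split_param_name() is assumed; "General" and "legacy" below are placeholders:

    # Legacy export layout: a flat mapping under execution/parameters,
    # optionally using "Section/name" keys.
    task_data = {"execution": {"parameters": {"epochs": 10, "Args/lr": "0.01"}}}

    # After _upgrade_task_data(task_data), each value becomes a typed entry grouped
    # by section, values are stringified, and execution/parameters is removed:
    # {
    #     "execution": {},
    #     "hyperparams": {
    #         "General": {"epochs": {"section": "General", "name": "epochs",
    #                                "type": "legacy", "value": "10"}},
    #         "Args": {"lr": {"section": "Args", "name": "lr",
    #                         "type": "legacy", "value": "0.01"}},
    #     },
    # }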

View File

@@ -328,8 +328,10 @@ fixed_users_mode {
                     description: "Fixed users mode enabled"
                     type: boolean
                 }
-                migration_warning {
-                    type: boolean
+                server_errors {
+                    description: "Server initialization errors"
+                    type: object
+                    additionalProperties: True
                 }
             }
         }

View File

@@ -5,14 +5,20 @@ from hashlib import md5
 from flask import Flask, request, Response
 from flask_compress import Compress
 from flask_cors import CORS
+from semantic_version import Version
 from werkzeug.exceptions import BadRequest
 
 import database
 from apierrors.base import BaseError
 from bll.statistics.stats_reporter import StatisticsReporter
 from config import config, info
-from elastic.initialize import init_es_data
-from mongo.initialize import init_mongo_data, pre_populate_data
+from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
+from mongo.initialize import (
+    init_mongo_data,
+    pre_populate_data,
+    check_mongo_empty,
+    get_last_server_version,
+)
 from service_repo import ServiceRepo, APICall
 from service_repo.auth import AuthType
 from service_repo.errors import PathParsingError
@@ -38,13 +44,37 @@ database.initialize()
 
 # build a key that uniquely identifies specific mongo instance
 hosts_string = ";".join(sorted(database.get_hosts()))
 key = "db_init_" + md5(hosts_string.encode()).hexdigest()
-with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
-    empty_es = init_es_data()
-    empty_db = init_mongo_data()
-    if empty_es and not empty_db:
-        log.info(f"ES database seems not migrated")
-        info.missed_es_upgrade = True
-    if empty_db and config.get("apiserver.pre_populate.enabled", False):
+with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
+    upgrade_monitoring = config.get(
+        "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
+    )
+    try:
+        empty_es = check_elastic_empty()
+    except ElasticConnectionError as err:
+        if not upgrade_monitoring:
+            raise
+        log.error(err)
+        info.es_connection_error = True
+
+    empty_db = check_mongo_empty()
+    if upgrade_monitoring:
+        if not empty_db and (info.es_connection_error or empty_es):
+            if get_last_server_version() < Version("0.16.0"):
+                log.info(f"ES database seems not migrated")
+                info.missed_es_upgrade = True
+        proceed_with_init = not (info.es_connection_error or info.missed_es_upgrade)
+    else:
+        proceed_with_init = True
+
+    if proceed_with_init:
+        init_es_data()
+        init_mongo_data()
+
+    if (
+        proceed_with_init
+        and empty_db
+        and config.get("apiserver.pre_populate.enabled", False)
+    ):
         pre_populate_data()
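
In effect, when v16_migration_verification is enabled (the default), ES and Mongo initialization and pre-population are skipped rather than run blindly whenever the ES connection probe fails, or whenever MongoDB already holds data, ES looks empty, and the last recorded server version is below 0.16.0. The detected condition is recorded on info.es_connection_error or info.missed_es_upgrade so it can be reported through the API instead of failing startup; with the flag disabled, a connection error is re-raised and behaviour matches the previous release.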

View File

@@ -176,12 +176,19 @@ def update(call, company_id, _):
 
 @endpoint("auth.fixed_users_mode")
 def fixed_users_mode(call: APICall, *_, **__):
+    server_errors = {
+        name: error
+        for name, error in zip(
+            ("missed_es_upgrade", "es_connection_error"),
+            (info.missed_es_upgrade, info.es_connection_error),
+        )
+        if error
+    }
     data = {
         "enabled": FixedUser.enabled(),
-        "migration_warning": info.missed_es_upgrade,
-        "guest": {
-            "enabled": FixedUser.guest_enabled(),
-        }
+        "guest": {"enabled": FixedUser.guest_enabled()},
+        "server_errors": server_errors,
     }
     guest_user = FixedUser.get_guest_user()
     if guest_user:
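
A sketch of the resulting auth.fixed_users_mode response data when, for example, the 0.16 ES migration was detected as missing (values are illustrative; only truthy flags end up in server_errors):

    data = {
        "enabled": True,
        "guest": {"enabled": False},
        "server_errors": {"missed_es_upgrade": True},
    }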