Introduce app startup sequence

Commit c67a56eb8d by allegroai, 2021-01-05 16:25:17 +02:00 (parent df65e1c7ad)
10 changed files with 273 additions and 231 deletions

apiserver/app_routes.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from flask import Flask, request, Response
from werkzeug.exceptions import BadRequest

from apiserver.apierrors.base import BaseError
from apiserver.config import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.errors import PathParsingError
from apiserver.timing_context import TimingContext
from apiserver.utilities import json

log = config.logger(__file__)


def before_app_first_request():
    pass
def before_request():
    if request.method == "OPTIONS":
        return "", 200
    if "/static/" in request.path:
        return
    if request.path.startswith("/rq"):
        return

    try:
        call = create_api_call(request)
        content, content_type = ServiceRepo.handle_call(call)

        headers = {}
        if call.result.filename:
            headers[
                "Content-Disposition"
            ] = f"attachment; filename={call.result.filename}"

        if call.result.headers:
            headers.update(call.result.headers)

        response = Response(
            content, mimetype=content_type, status=call.result.code, headers=headers
        )

        if call.result.cookies:
            for key, value in call.result.cookies.items():
                if value is None:
                    response.set_cookie(
                        key, "", expires=0, **config.get("apiserver.auth.cookies")
                    )
                else:
                    response.set_cookie(
                        key, value, **config.get("apiserver.auth.cookies")
                    )

        return response
    except Exception as ex:
        log.exception(f"Failed processing request {request.url}: {ex}")
        return f"Failed processing request {request.url}", 500
def update_call_data(call, req):
    """ Use request payload/form to fill call data or batched data """
    if req.content_type == "application/json-lines":
        items = []
        for i, line in enumerate(req.data.splitlines()):
            try:
                event = json.loads(line)
                if not isinstance(event, dict):
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(event).__name__}"
                    )
                items.append(event)
            except ValueError as e:
                msg = f"{e} in batch item #{i}"
                req.on_json_loading_failed(msg)
        call.batched_data = items
    else:
        json_body = req.get_json(force=True, silent=False) if req.data else None
        # merge form and args
        form = req.form.copy()
        form.update(req.args)
        form = form.to_dict()
        # convert string numbers and booleans to native types
        for key in form:
            if form[key].replace(".", "", 1).isdigit():
                if "." in form[key]:
                    form[key] = float(form[key])
                else:
                    form[key] = int(form[key])
            elif form[key].lower() == "true":
                form[key] = True
            elif form[key].lower() == "false":
                form[key] = False
        call.data = json_body or form or {}
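The application/json-lines branch lets a client batch many JSON objects in one request body, one per line, while the fallback branch merges the JSON body with form and query-string values and coerces numeric/boolean strings. A client-side sketch of both, assuming a locally running apiserver; the address and endpoint names are illustrative, not taken from this commit:

import json
import requests

BASE = "http://localhost:8008"  # assumed apiserver address

# Batched call: one JSON object per line ends up in call.batched_data
events = [{"task": "t1", "type": "log"}, {"task": "t2", "type": "log"}]
requests.post(
    f"{BASE}/events.add_batch",  # hypothetical endpoint name
    data="\n".join(json.dumps(e) for e in events),
    headers={"Content-Type": "application/json-lines"},
)

# Query-string call: "500" becomes int 500, "true" becomes True in call.data
requests.get(f"{BASE}/tasks.get_all?page_size=500&active=true")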
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
    call = call or APICall(
        "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
    )
    call.set_error_result(msg=msg, code=code, subcode=subcode)
    return call
def create_api_call(req):
    call = None
    try:
        # Parse the request path
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)

        # Resolve authorization: if cookies contain an authorization token, use it
        # as a starting point. In any case, request headers always take precedence.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = (
            {}
            if not auth_cookie
            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
        )
        headers.update(
            list(req.headers.items())
        )  # add (possibly override with) the headers

        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
            host=req.host,
        )

        # Update call data from request
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)

    except PathParsingError as ex:
        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = _call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )

    return call
def register_routes(app: Flask):
    app.before_first_request(before_app_first_request)
    app.before_request(before_request)
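Taken together: register_routes funnels every non-static request through before_request, create_api_call resolves the endpoint version and name from the path (e.g. a /vX.Y/ prefix), and an explicit Authorization header always wins over the session cookie. A minimal client sketch of one round trip; host, port, endpoint, and token are assumptions for illustration:

import requests

resp = requests.post(
    "http://localhost:8008/v2.4/debug.ping",  # assumed local apiserver and endpoint
    json={},
    headers={"Authorization": "Bearer <token>"},  # overrides any cookie-derived token
)
print(resp.status_code, resp.headers.get("Content-Type"))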

apiserver/app_sequence.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import atexit
from hashlib import md5

from flask import Flask
from semantic_version import Version

import database
from bll.statistics.stats_reporter import StatisticsReporter
from config import config, info
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
    init_mongo_data,
    pre_populate_data,
    check_mongo_empty,
    get_last_server_version,
)
from service_repo import ServiceRepo
from sync import distributed_lock
from updates import check_updates_thread
from utilities.threads_manager import ThreadsManager

log = config.logger(__file__)


class AppSequence:
    def __init__(self, app: Flask):
        self.app = app

    def start(self):
        log.info("################ API Server initializing #####################")
        self._configure()
        self._init_dbs()
        self._load_services()
        self._start_worker()
        atexit.register(self._on_worker_stop)
    def _configure(self):
        self.app.config["SECRET_KEY"] = config.get(
            "secure.http.session_secret.apiserver"
        )
        self.app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get(
            "apiserver.pretty_json"
        )
    def _init_dbs(self):
        database.initialize()

        # build a key that uniquely identifies this specific mongo instance
        hosts_string = ";".join(sorted(database.get_hosts()))
        key = "db_init_" + md5(hosts_string.encode()).hexdigest()
        with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
            upgrade_monitoring = config.get(
                "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
            )
            try:
                empty_es = check_elastic_empty()
            except ElasticConnectionError as err:
                if not upgrade_monitoring:
                    raise
                log.error(err)
                info.es_connection_error = True

            empty_db = check_mongo_empty()
            if (
                upgrade_monitoring
                and not empty_db
                and (info.es_connection_error or empty_es)
                and get_last_server_version() < Version("0.16.0")
            ):
                log.info("ES database seems not migrated")
                info.missed_es_upgrade = True

            if info.es_connection_error and not info.missed_es_upgrade:
                raise Exception(
                    "Error starting server: failed connecting to ElasticSearch service"
                )

            if not info.missed_es_upgrade:
                init_es_data()
                init_mongo_data()

            if (
                not info.missed_es_upgrade
                and empty_db
                and config.get("apiserver.pre_populate.enabled", False)
            ):
                pre_populate_data()
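Several apiserver instances may initialize against the same MongoDB concurrently, so _init_dbs runs under a distributed lock whose name is derived from the sorted host list: every instance configured with the same hosts computes the same key, regardless of ordering. A standalone sketch of the derivation, with made-up host names:

from hashlib import md5

hosts = ["mongo-2:27017", "mongo-1:27017"]  # assumed example hosts
hosts_string = ";".join(sorted(hosts))      # sorting makes the key order-independent
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
print(key)  # identical on every instance pointing at these hosts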
    def _load_services(self):
        ServiceRepo.load("services")
        log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")

    def _start_worker(self):
        check_updates_thread.start()
        StatisticsReporter.start()

    def _on_worker_stop(self):
        ThreadsManager.terminating = True

@@ -12,7 +12,7 @@ from boltons.dictutils import subdict
 from pyhocon import ConfigFactory

 from config import config
-from service_repo.base import PartialVersion
+from utilities.partial_version import PartialVersion

 HERE = Path(__file__)

apiserver/server.py
@@ -1,237 +1,18 @@
-import atexit
 from argparse import ArgumentParser
-from hashlib import md5

-from flask import Flask, request, Response
+from flask import Flask
 from flask_compress import Compress
 from flask_cors import CORS
-from semantic_version import Version
-from werkzeug.exceptions import BadRequest

-import database
-from apierrors.base import BaseError
-from bll.statistics.stats_reporter import StatisticsReporter
-from config import config, info
-from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
-from mongo.initialize import (
-    init_mongo_data,
-    pre_populate_data,
-    check_mongo_empty,
-    get_last_server_version,
-)
-from service_repo import ServiceRepo, APICall
-from service_repo.auth import AuthType
-from service_repo.errors import PathParsingError
-from sync import distributed_lock
-from timing_context import TimingContext
-from updates import check_updates_thread
-from utilities import json
-from utilities.threads_manager import ThreadsManager
+from app_routes import register_routes
+from app_sequence import AppSequence
+from config import config

 app = Flask(__name__, static_url_path="/static")
 CORS(app, **config.get("apiserver.cors"))
 Compress(app)
+register_routes(app)
+AppSequence(app).start()

-log = config.logger(__file__)
-
-log.info("################ API Server initializing #####################")
-
-app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
-app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
-
-database.initialize()
-
-# build a key that uniquely identifies specific mongo instance
-hosts_string = ";".join(sorted(database.get_hosts()))
-key = "db_init_" + md5(hosts_string.encode()).hexdigest()
-with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
-    upgrade_monitoring = config.get(
-        "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
-    )
-    try:
-        empty_es = check_elastic_empty()
-    except ElasticConnectionError as err:
-        if not upgrade_monitoring:
-            raise
-        log.error(err)
-        info.es_connection_error = True
-
-    empty_db = check_mongo_empty()
-    if (
-        upgrade_monitoring
-        and not empty_db
-        and (info.es_connection_error or empty_es)
-        and get_last_server_version() < Version("0.16.0")
-    ):
-        log.info(f"ES database seems not migrated")
-        info.missed_es_upgrade = True
-
-    if info.es_connection_error and not info.missed_es_upgrade:
-        raise Exception(
-            "Error starting server: failed connecting to ElasticSearch service"
-        )
-
-    if not info.missed_es_upgrade:
-        init_es_data()
-        init_mongo_data()
-
-    if (
-        not info.missed_es_upgrade
-        and empty_db
-        and config.get("apiserver.pre_populate.enabled", False)
-    ):
-        pre_populate_data()
-
-ServiceRepo.load("services")
-log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
-
-check_updates_thread.start()
-StatisticsReporter.start()
-
-
-def graceful_shutdown():
-    ThreadsManager.terminating = True
-
-
-atexit.register(graceful_shutdown)
-
-
-@app.before_first_request
-def before_app_first_request():
-    pass
-
-
-@app.before_request
-def before_request():
-    if request.method == "OPTIONS":
-        return "", 200
-    if "/static/" in request.path:
-        return
-
-    try:
-        call = create_api_call(request)
-        content, content_type = ServiceRepo.handle_call(call)
-        headers = {}
-        if call.result.filename:
-            headers[
-                "Content-Disposition"
-            ] = f"attachment; filename={call.result.filename}"
-
-        if call.result.headers:
-            headers.update(call.result.headers)
-        response = Response(
-            content, mimetype=content_type, status=call.result.code, headers=headers
-        )
-
-        if call.result.cookies:
-            for key, value in call.result.cookies.items():
-                if value is None:
-                    response.set_cookie(key, "", expires=0)
-                else:
-                    response.set_cookie(
-                        key, value, **config.get("apiserver.auth.cookies")
-                    )
-
-        return response
-    except Exception as ex:
-        log.exception(f"Failed processing request {request.url}: {ex}")
-        return f"Failed processing request {request.url}", 500
-
-
-def update_call_data(call, req):
-    """ Use request payload/form to fill call data or batched data """
-    if req.content_type == "application/json-lines":
-        items = []
-        for i, line in enumerate(req.data.splitlines()):
-            try:
-                event = json.loads(line)
-                if not isinstance(event, dict):
-                    raise BadRequest(
-                        f"json lines must contain objects, found: {type(event).__name__}"
-                    )
-                items.append(event)
-            except ValueError as e:
-                msg = f"{e} in batch item #{i}"
-                req.on_json_loading_failed(msg)
-        call.batched_data = items
-    else:
-        json_body = req.get_json(force=True, silent=False) if req.data else None
-        # merge form and args
-        form = req.form.copy()
-        form.update(req.args)
-        form = form.to_dict()
-        # convert string numbers to floats
-        for key in form:
-            if form[key].replace(".", "", 1).isdigit():
-                if "." in form[key]:
-                    form[key] = float(form[key])
-                else:
-                    form[key] = int(form[key])
-            elif form[key].lower() == "true":
-                form[key] = True
-            elif form[key].lower() == "false":
-                form[key] = False
-        call.data = json_body or form or {}
-
-
-def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
-    call = call or APICall(
-        "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
-    )
-    call.set_error_result(msg=msg, code=code, subcode=subcode)
-    return call
-
-
-def create_api_call(req):
-    call = None
-    try:
-        # Parse the request path
-        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
-
-        # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
-        # in any case, request headers always take precedence.
-        auth_cookie = req.cookies.get(
-            config.get("apiserver.auth.session_auth_cookie_name")
-        )
-        headers = (
-            {}
-            if not auth_cookie
-            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
-        )
-        headers.update(
-            list(req.headers.items())
-        )  # add (possibly override with) the headers
-
-        # Construct call instance
-        call = APICall(
-            endpoint_name=endpoint_name,
-            remote_addr=req.remote_addr,
-            endpoint_version=endpoint_version,
-            headers=headers,
-            files=req.files,
-        )
-
-        # Update call data from request
-        with TimingContext("preprocess", "update_call_data"):
-            update_call_data(call, req)
-
-    except PathParsingError as ex:
-        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
-        call.log_api = False
-    except BadRequest as ex:
-        call = _call_or_empty_with_error(call, req, ex.description, 400)
-    except BaseError as ex:
-        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
-    except Exception as ex:
-        log.exception("Error creating call")
-        call = _call_or_empty_with_error(
-            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
-        )
-
-    return call
-

 # =================== MAIN =======================

apiserver/service_repo/apicall.py
@@ -9,9 +9,9 @@ from six import string_types
 import database
 from timing_context import TimingContext, TimingStats
 from utilities import json
+from utilities.partial_version import PartialVersion
 from .auth import Identity
 from .auth import Payload as AuthPayload
-from .base import PartialVersion
 from .errors import CallParsingError
 from .schema_validator import SchemaValidator

@@ -305,6 +305,7 @@ class APICall(DataContainer):
         headers=None,
         files=None,
         trx=None,
+        **kwargs,
     ):
         super(APICall, self).__init__(data=data, batched_data=batched_data)
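The **kwargs added to APICall.__init__ makes the constructor accept keyword arguments beyond the listed ones; app_routes.py above now passes host=req.host, which the old signature would have rejected. Whether host is consumed elsewhere isn't visible in this hunk. A toy illustration of the pattern (not the real APICall class):

class Call:
    def __init__(self, endpoint_name, headers=None, files=None, trx=None, **kwargs):
        # Unrecognized keyword arguments (e.g. the new host=...) are accepted
        # here without error, so newer callers don't break this signature.
        self.endpoint_name = endpoint_name

Call("tasks.get_all", host="localhost:8008")  # extra kwarg absorbed by **kwargs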

apiserver/service_repo/endpoint.py
@@ -5,8 +5,8 @@ from jsonmodels import models
 from jsonmodels.errors import FieldNotSupported

 from schema import schema
+from utilities.partial_version import PartialVersion
 from .apicall import APICall
-from .base import PartialVersion
 from .schema_validator import SchemaValidator

 EndpointFunc = Callable[[APICall, Text, models.Base], None]

apiserver/service_repo/service_repo.py
@@ -11,7 +11,7 @@ from apierrors import APIError
 from apierrors.errors.bad_request import RequestPathHasInvalidVersion
 from api_version import __version__ as _api_version_
 from config import config
-from service_repo.base import PartialVersion
+from utilities.partial_version import PartialVersion
 from .apicall import APICall
 from .endpoint import Endpoint
 from .errors import MalformedPathError, InvalidVersionError, CallFailedError

apiserver/services/tasks.py
@@ -69,9 +69,9 @@ from database.model.task.task import (
 )
 from database.utils import get_fields, parse_from_call
 from service_repo import APICall, endpoint
-from service_repo.base import PartialVersion
 from services.utils import conform_tag_fields, conform_output_tags, validate_tags
 from timing_context import TimingContext
+from utilities.partial_version import PartialVersion

 task_fields = set(Task.get_fields())
 task_script_fields = set(get_fields(Script))

@@ -5,7 +5,7 @@ from apimodels.organization import Filter
 from database.model.base import GetMixin
 from database.utils import partition_tags
 from service_repo import APICall
-from service_repo.base import PartialVersion
+from utilities.partial_version import PartialVersion


 def get_tags_filter_dictionary(input_: Filter) -> dict:

apiserver/utilities/partial_version.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from semantic_version import Version


class PartialVersion(Version):
    def __init__(self, version_string: str):
        assert isinstance(version_string, str)
        super().__init__(version_string, partial=True)
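PartialVersion wraps semantic_version.Version with partial=True, so version strings that omit the patch or minor component (as in endpoint paths and stored server versions) still parse and compare. A small usage sketch; behavior follows the semantic_version library:

from semantic_version import Version
from utilities.partial_version import PartialVersion

v = PartialVersion("2.4")        # patch component may be omitted
print(v.major, v.minor, v.patch)  # 2 4 None
print(v < Version("2.5.0"))       # True
print(PartialVersion("0.15") < Version("0.16.0"))  # True; cf. the check in app_sequence.py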