Introduce app startup sequence

This commit is contained in:
allegroai 2021-01-05 16:25:17 +02:00
parent df65e1c7ad
commit c67a56eb8d
10 changed files with 273 additions and 231 deletions

155
apiserver/app_routes.py Normal file
View File

@@ -0,0 +1,155 @@
from flask import Flask, request, Response
from werkzeug.exceptions import BadRequest
from apiserver.apierrors.base import BaseError
from apiserver.config import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.errors import PathParsingError
from apiserver.timing_context import TimingContext
from apiserver.utilities import json
log = config.logger(__file__)
def before_app_first_request():
    """Flask before-first-request hook; currently a deliberate no-op placeholder."""
    return None
def before_request():
    """Flask per-request hook that routes API requests through ServiceRepo.

    CORS preflight requests are answered immediately; static assets and the
    /rq path are left to their own handlers. Everything else is turned into
    an APICall, dispatched, and rendered back as a Flask Response (including
    result headers and cookies). Any failure yields a 500 text response.
    """
    if request.method == "OPTIONS":
        # CORS preflight - reply with an empty OK, no dispatch needed
        return "", 200
    if "/static/" in request.path or request.path.startswith("/rq"):
        # Not an API endpoint; let other handlers serve it
        return
    try:
        call = create_api_call(request)
        body, mimetype = ServiceRepo.handle_call(call)

        extra_headers = {}
        if call.result.filename:
            # Serve the payload as a downloadable attachment
            extra_headers[
                "Content-Disposition"
            ] = f"attachment; filename={call.result.filename}"
        if call.result.headers:
            extra_headers.update(call.result.headers)

        response = Response(
            body, mimetype=mimetype, status=call.result.code, headers=extra_headers
        )

        if call.result.cookies:
            for name, value in call.result.cookies.items():
                if value is None:
                    # A None value means "delete": send an already-expired cookie
                    response.set_cookie(
                        name, "", expires=0, **config.get("apiserver.auth.cookies")
                    )
                else:
                    response.set_cookie(
                        name, value, **config.get("apiserver.auth.cookies")
                    )
        return response
    except Exception as ex:
        log.exception(f"Failed processing request {request.url}: {ex}")
        return f"Failed processing request {request.url}", 500
def update_call_data(call, req):
    """Populate the call from the request payload.

    For ``application/json-lines`` content, each line must be a JSON object;
    the parsed objects become ``call.batched_data`` (invalid JSON lines are
    reported via ``req.on_json_loading_failed``). Otherwise the JSON body, or
    the merged form/query-string fields (with numeric- and boolean-looking
    strings coerced to native types), become ``call.data``.
    """
    if req.content_type == "application/json-lines":
        batch = []
        for index, raw_line in enumerate(req.data.splitlines()):
            try:
                item = json.loads(raw_line)
                if not isinstance(item, dict):
                    # Non-object lines are a client error, not a parse error
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(item).__name__}"
                    )
                batch.append(item)
            except ValueError as e:
                msg = f"{e} in batch item #{index}"
                req.on_json_loading_failed(msg)
        call.batched_data = batch
        return

    json_body = req.get_json(force=True, silent=False) if req.data else None
    # merge form and args (query-string values win on duplicate keys)
    merged = req.form.copy()
    merged.update(req.args)
    merged = merged.to_dict()
    # coerce numeric-looking and boolean-looking strings to native types
    for key, raw in merged.items():
        if raw.replace(".", "", 1).isdigit():
            merged[key] = float(raw) if "." in raw else int(raw)
        elif raw.lower() == "true":
            merged[key] = True
        elif raw.lower() == "false":
            merged[key] = False
    call.data = json_body or merged or {}
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
call = call or APICall(
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
)
call.set_error_result(msg=msg, code=code, subcode=subcode)
return call
def create_api_call(req):
    """Build an APICall from a Flask request.

    Never raises: any failure is converted into a call object carrying an
    error result (path-parsing errors additionally disable API logging for
    the call).
    """
    call = None
    try:
        # Parse the request path into version + endpoint name
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)

        # A session cookie may carry an auth token as a starting point, but
        # actual request headers always take precedence over it.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = {}
        if auth_cookie:
            headers["Authorization"] = f"{AuthType.bearer_token} {auth_cookie}"
        headers.update(list(req.headers.items()))

        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
            host=req.host,
        )
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)
    except PathParsingError as ex:
        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
        # Malformed paths are client noise; keep them out of the API log
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = _call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )
    return call
def register_routes(app: Flask):
    """Hook this module's request handlers into the given Flask app."""
    app.before_first_request(before_app_first_request)
    app.before_request(before_request)

98
apiserver/app_sequence.py Normal file
View File

@@ -0,0 +1,98 @@
import atexit
from hashlib import md5
from flask import Flask
from semantic_version import Version
import database
from bll.statistics.stats_reporter import StatisticsReporter
from config import config, info
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
init_mongo_data,
pre_populate_data,
check_mongo_empty,
get_last_server_version,
)
from service_repo import ServiceRepo
from sync import distributed_lock
from updates import check_updates_thread
from utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
class AppSequence:
    """Encapsulates the API server startup sequence: configure the Flask app,
    initialize the databases, load the service endpoints and start background
    workers (with cleanup registered for process exit)."""
    def __init__(self, app: Flask):
        # Flask app to configure during startup
        self.app = app
    def start(self):
        """Run the full startup sequence and register exit-time cleanup."""
        log.info("################ API Server initializing #####################")
        self._configure()
        self._init_dbs()
        self._load_services()
        self._start_worker()
        atexit.register(self._on_worker_stop)
    def _configure(self):
        # Session signing key and JSON pretty-printing, taken from config
        self.app.config["SECRET_KEY"] = config.get(
            "secure.http.session_secret.apiserver"
        )
        self.app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get(
            "apiserver.pretty_json"
        )
    def _init_dbs(self):
        """Initialize mongo and ES data, serialized across server instances
        via a distributed lock, and detect a missed pre-0.16 ES migration."""
        database.initialize()
        # build a key that uniquely identifies specific mongo instance
        hosts_string = ";".join(sorted(database.get_hosts()))
        key = "db_init_" + md5(hosts_string.encode()).hexdigest()
        # NOTE(review): "db_init_timout" (sic) is the actual config key - do not "fix" it
        with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
            upgrade_monitoring = config.get(
                "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
            )
            try:
                empty_es = check_elastic_empty()
            except ElasticConnectionError as err:
                # An ES connection failure is tolerated only while upgrade
                # monitoring is enabled; it is re-checked further below
                if not upgrade_monitoring:
                    raise
                log.error(err)
                info.es_connection_error = True
            empty_db = check_mongo_empty()
            # Existing mongo data with empty/unreachable ES on a pre-0.16
            # deployment suggests ES data was never migrated. (When the
            # connection failed, empty_es is unbound, but the `or`
            # short-circuits on es_connection_error.)
            if (
                upgrade_monitoring
                and not empty_db
                and (info.es_connection_error or empty_es)
                and get_last_server_version() < Version("0.16.0")
            ):
                log.info(f"ES database seems not migrated")
                info.missed_es_upgrade = True
            if info.es_connection_error and not info.missed_es_upgrade:
                raise Exception(
                    "Error starting server: failed connecting to ElasticSearch service"
                )
            if not info.missed_es_upgrade:
                init_es_data()
                init_mongo_data()
            if (
                not info.missed_es_upgrade
                and empty_db
                and config.get("apiserver.pre_populate.enabled", False)
            ):
                pre_populate_data()
    def _load_services(self):
        # Discover and register all API endpoints under the services package
        ServiceRepo.load("services")
        log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
    def _start_worker(self):
        # Background workers: periodic update checks and statistics reporting
        check_updates_thread.start()
        StatisticsReporter.start()
    def _on_worker_stop(self):
        # Signal managed threads to terminate on process exit
        ThreadsManager.terminating = True

View File

@@ -12,7 +12,7 @@ from boltons.dictutils import subdict
from pyhocon import ConfigFactory
from config import config
from service_repo.base import PartialVersion
from utilities.partial_version import PartialVersion
HERE = Path(__file__)

View File

@@ -1,237 +1,18 @@
import atexit
from argparse import ArgumentParser
from hashlib import md5
from flask import Flask, request, Response
from flask import Flask
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config, info
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
init_mongo_data,
pre_populate_data,
check_mongo_empty,
get_last_server_version,
)
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from sync import distributed_lock
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
from utilities.threads_manager import ThreadsManager
from app_routes import register_routes
from app_sequence import AppSequence
from config import config
# Flask application: static files under /static, CORS and gzip compression
# configured from the server config.
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
Compress(app)
log = config.logger(__file__)
log.info("################ API Server initializing #####################")
# Session signing key and JSON pretty-printing behavior
app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
database.initialize()
# build a key that uniquely identifies specific mongo instance
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
# Serialize DB initialization across server instances sharing the same mongo.
# NOTE(review): "db_init_timout" (sic) is the actual config key
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
    upgrade_monitoring = config.get(
        "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
    )
    try:
        empty_es = check_elastic_empty()
    except ElasticConnectionError as err:
        # Tolerate an ES connection failure only while upgrade monitoring is
        # enabled; the error flag is re-checked below
        if not upgrade_monitoring:
            raise
        log.error(err)
        info.es_connection_error = True
    empty_db = check_mongo_empty()
    # Existing mongo data with empty/unreachable ES on a pre-0.16 deployment
    # suggests the ES data was never migrated. (When the connection failed,
    # empty_es is unbound, but the `or` short-circuits on es_connection_error.)
    if (
        upgrade_monitoring
        and not empty_db
        and (info.es_connection_error or empty_es)
        and get_last_server_version() < Version("0.16.0")
    ):
        log.info(f"ES database seems not migrated")
        info.missed_es_upgrade = True
    if info.es_connection_error and not info.missed_es_upgrade:
        raise Exception(
            "Error starting server: failed connecting to ElasticSearch service"
        )
    if not info.missed_es_upgrade:
        init_es_data()
        init_mongo_data()
    if (
        not info.missed_es_upgrade
        and empty_db
        and config.get("apiserver.pre_populate.enabled", False)
    ):
        pre_populate_data()
# Discover and register all API endpoints under the services package
ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
# Background workers: periodic update checks and statistics reporting
check_updates_thread.start()
StatisticsReporter.start()
def graceful_shutdown():
    # Signal managed threads to terminate on process exit
    ThreadsManager.terminating = True
atexit.register(graceful_shutdown)
@app.before_first_request
def before_app_first_request():
    """Flask before-first-request hook; currently a deliberate no-op placeholder."""
    pass
@app.before_request
def before_request():
    """Flask per-request hook that routes API requests through ServiceRepo.

    CORS preflight requests are answered immediately and static assets are
    left to their own handler; everything else is turned into an APICall,
    dispatched, and rendered back as a Flask Response (including result
    headers and cookies). Any failure yields a 500 text response.
    """
    if request.method == "OPTIONS":
        # CORS preflight - empty OK, no dispatch needed
        return "", 200
    if "/static/" in request.path:
        return
    try:
        call = create_api_call(request)
        content, content_type = ServiceRepo.handle_call(call)
        headers = {}
        if call.result.filename:
            # Serve the payload as a downloadable attachment
            headers[
                "Content-Disposition"
            ] = f"attachment; filename={call.result.filename}"
        if call.result.headers:
            headers.update(call.result.headers)
        response = Response(
            content, mimetype=content_type, status=call.result.code, headers=headers
        )
        if call.result.cookies:
            for key, value in call.result.cookies.items():
                if value is None:
                    # A None value means "delete": send an already-expired cookie
                    response.set_cookie(key, "", expires=0)
                else:
                    response.set_cookie(
                        key, value, **config.get("apiserver.auth.cookies")
                    )
        return response
    except Exception as ex:
        log.exception(f"Failed processing request {request.url}: {ex}")
        return f"Failed processing request {request.url}", 500
def update_call_data(call, req):
    """ Use request payload/form to fill call data or batched data """
    if req.content_type == "application/json-lines":
        # Batched mode: every line must parse to a JSON object; the parsed
        # objects become call.batched_data
        items = []
        for i, line in enumerate(req.data.splitlines()):
            try:
                event = json.loads(line)
                if not isinstance(event, dict):
                    # Non-object lines are a client error, not a parse error
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(event).__name__}"
                    )
                items.append(event)
            except ValueError as e:
                # Report the offending line via the request's JSON error hook
                msg = f"{e} in batch item #{i}"
                req.on_json_loading_failed(msg)
        call.batched_data = items
    else:
        json_body = req.get_json(force=True, silent=False) if req.data else None
        # merge form and args
        form = req.form.copy()
        form.update(req.args)
        form = form.to_dict()
        # convert string numbers to floats
        for key in form:
            if form[key].replace(".", "", 1).isdigit():
                if "." in form[key]:
                    form[key] = float(form[key])
                else:
                    form[key] = int(form[key])
            elif form[key].lower() == "true":
                form[key] = True
            elif form[key].lower() == "false":
                form[key] = False
        # JSON body wins over form/query data when present
        call.data = json_body or form or {}
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
    """Attach an error result to *call*, building a bare APICall from the
    raw request when no call object was successfully constructed."""
    call = call or APICall(
        "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
    )
    call.set_error_result(msg=msg, code=code, subcode=subcode)
    return call
def create_api_call(req):
    """Build an APICall from a Flask request.

    Never raises: any failure is converted into a call object carrying an
    error result (path-parsing errors additionally disable API logging for
    the call).
    """
    call = None
    try:
        # Parse the request path
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
        # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
        # in any case, request headers always take precedence.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = (
            {}
            if not auth_cookie
            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
        )
        headers.update(
            list(req.headers.items())
        )  # add (possibly override with) the headers
        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
        )
        # Update call data from request
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)
    except PathParsingError as ex:
        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
        # Malformed paths are client noise; keep them out of the API log
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = _call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )
    return call
# Wire request handlers into the app, then run the startup sequence
register_routes(app)
AppSequence(app).start()
# =================== MAIN =======================

View File

@@ -9,9 +9,9 @@ from six import string_types
import database
from timing_context import TimingContext, TimingStats
from utilities import json
from utilities.partial_version import PartialVersion
from .auth import Identity
from .auth import Payload as AuthPayload
from .base import PartialVersion
from .errors import CallParsingError
from .schema_validator import SchemaValidator
@@ -305,6 +305,7 @@ class APICall(DataContainer):
headers=None,
files=None,
trx=None,
**kwargs,
):
super(APICall, self).__init__(data=data, batched_data=batched_data)

View File

@@ -5,8 +5,8 @@ from jsonmodels import models
from jsonmodels.errors import FieldNotSupported
from schema import schema
from utilities.partial_version import PartialVersion
from .apicall import APICall
from .base import PartialVersion
from .schema_validator import SchemaValidator
EndpointFunc = Callable[[APICall, Text, models.Base], None]

View File

@@ -11,7 +11,7 @@ from apierrors import APIError
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
from api_version import __version__ as _api_version_
from config import config
from service_repo.base import PartialVersion
from utilities.partial_version import PartialVersion
from .apicall import APICall
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError

View File

@@ -69,9 +69,9 @@ from database.model.task.task import (
)
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from service_repo.base import PartialVersion
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
from timing_context import TimingContext
from utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))

View File

@@ -5,7 +5,7 @@ from apimodels.organization import Filter
from database.model.base import GetMixin
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
from utilities.partial_version import PartialVersion
def get_tags_filter_dictionary(input_: Filter) -> dict:

View File

@@ -0,0 +1,7 @@
from semantic_version import Version
class PartialVersion(Version):
    """A semantic version that permits missing components (e.g. "1.2"),
    by always constructing the underlying Version in partial mode.

    :param version_string: the (possibly partial) version string to parse
    :raises TypeError: if *version_string* is not a string
    """

    def __init__(self, version_string: str):
        # Validate explicitly instead of `assert`, which is silently
        # stripped when Python runs with optimizations (-O).
        if not isinstance(version_string, str):
            raise TypeError(
                f"version_string must be a string, got {type(version_string).__name__}"
            )
        super().__init__(version_string, partial=True)