mirror of
https://github.com/clearml/clearml-server
synced 2025-04-20 22:24:29 +00:00
Introduce app startup sequence
This commit is contained in:
parent
df65e1c7ad
commit
c67a56eb8d
155
apiserver/app_routes.py
Normal file
155
apiserver/app_routes.py
Normal file
@ -0,0 +1,155 @@
|
||||
from flask import Flask, request, Response
|
||||
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def before_app_first_request():
|
||||
pass
|
||||
|
||||
|
||||
def before_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
if request.path.startswith("/rq"):
|
||||
return
|
||||
|
||||
try:
|
||||
call = create_api_call(request)
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
headers = {}
|
||||
if call.result.filename:
|
||||
headers[
|
||||
"Content-Disposition"
|
||||
] = f"attachment; filename={call.result.filename}"
|
||||
|
||||
if call.result.headers:
|
||||
headers.update(call.result.headers)
|
||||
|
||||
response = Response(
|
||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
||||
)
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
if value is None:
|
||||
response.set_cookie(
|
||||
key, "", expires=0, **config.get("apiserver.auth.cookies")
|
||||
)
|
||||
else:
|
||||
response.set_cookie(
|
||||
key, value, **config.get("apiserver.auth.cookies")
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
|
||||
def update_call_data(call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
items = []
|
||||
for i, line in enumerate(req.data.splitlines()):
|
||||
try:
|
||||
event = json.loads(line)
|
||||
if not isinstance(event, dict):
|
||||
raise BadRequest(
|
||||
f"json lines must contain objects, found: {type(event).__name__}"
|
||||
)
|
||||
items.append(event)
|
||||
except ValueError as e:
|
||||
msg = f"{e} in batch item #{i}"
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
|
||||
|
||||
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
|
||||
call = call or APICall(
|
||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||
)
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
return call
|
||||
|
||||
|
||||
def create_api_call(req):
|
||||
call = None
|
||||
try:
|
||||
# Parse the request path
|
||||
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
auth_cookie = req.cookies.get(
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
)
|
||||
headers = (
|
||||
{}
|
||||
if not auth_cookie
|
||||
else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
|
||||
)
|
||||
headers.update(
|
||||
list(req.headers.items())
|
||||
) # add (possibly override with) the headers
|
||||
|
||||
# Construct call instance
|
||||
call = APICall(
|
||||
endpoint_name=endpoint_name,
|
||||
remote_addr=req.remote_addr,
|
||||
endpoint_version=endpoint_version,
|
||||
headers=headers,
|
||||
files=req.files,
|
||||
host=req.host,
|
||||
)
|
||||
|
||||
# Update call data from request
|
||||
with TimingContext("preprocess", "update_call_data"):
|
||||
update_call_data(call, req)
|
||||
|
||||
except PathParsingError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.args[0], 400)
|
||||
call.log_api = False
|
||||
except BadRequest as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.description, 400)
|
||||
except BaseError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
||||
except Exception as ex:
|
||||
log.exception("Error creating call")
|
||||
call = _call_or_empty_with_error(call, req, ex.args[0] if ex.args else type(ex).__name__, 500)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def register_routes(app: Flask):
|
||||
app.before_first_request(before_app_first_request)
|
||||
app.before_request(before_request)
|
98
apiserver/app_sequence.py
Normal file
98
apiserver/app_sequence.py
Normal file
@ -0,0 +1,98 @@
|
||||
import atexit
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask
|
||||
from semantic_version import Version
|
||||
|
||||
import database
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config, info
|
||||
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
from mongo.initialize import (
|
||||
init_mongo_data,
|
||||
pre_populate_data,
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from service_repo import ServiceRepo
|
||||
from sync import distributed_lock
|
||||
from updates import check_updates_thread
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class AppSequence:
|
||||
def __init__(self, app: Flask):
|
||||
self.app = app
|
||||
|
||||
def start(self):
|
||||
log.info("################ API Server initializing #####################")
|
||||
self._configure()
|
||||
self._init_dbs()
|
||||
self._load_services()
|
||||
self._start_worker()
|
||||
atexit.register(self._on_worker_stop)
|
||||
|
||||
def _configure(self):
|
||||
self.app.config["SECRET_KEY"] = config.get(
|
||||
"secure.http.session_secret.apiserver"
|
||||
)
|
||||
self.app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get(
|
||||
"apiserver.pretty_json"
|
||||
)
|
||||
|
||||
def _init_dbs(self):
|
||||
database.initialize()
|
||||
|
||||
# build a key that uniquely identifies specific mongo instance
|
||||
hosts_string = ";".join(sorted(database.get_hosts()))
|
||||
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
|
||||
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
|
||||
upgrade_monitoring = config.get(
|
||||
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
|
||||
)
|
||||
try:
|
||||
empty_es = check_elastic_empty()
|
||||
except ElasticConnectionError as err:
|
||||
if not upgrade_monitoring:
|
||||
raise
|
||||
log.error(err)
|
||||
info.es_connection_error = True
|
||||
|
||||
empty_db = check_mongo_empty()
|
||||
if (
|
||||
upgrade_monitoring
|
||||
and not empty_db
|
||||
and (info.es_connection_error or empty_es)
|
||||
and get_last_server_version() < Version("0.16.0")
|
||||
):
|
||||
log.info(f"ES database seems not migrated")
|
||||
info.missed_es_upgrade = True
|
||||
|
||||
if info.es_connection_error and not info.missed_es_upgrade:
|
||||
raise Exception(
|
||||
"Error starting server: failed connecting to ElasticSearch service"
|
||||
)
|
||||
|
||||
if not info.missed_es_upgrade:
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
|
||||
if (
|
||||
not info.missed_es_upgrade
|
||||
and empty_db
|
||||
and config.get("apiserver.pre_populate.enabled", False)
|
||||
):
|
||||
pre_populate_data()
|
||||
|
||||
def _load_services(self):
|
||||
ServiceRepo.load("services")
|
||||
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
def _start_worker(self):
|
||||
check_updates_thread.start()
|
||||
StatisticsReporter.start()
|
||||
|
||||
def _on_worker_stop(self):
|
||||
ThreadsManager.terminating = True
|
@ -12,7 +12,7 @@ from boltons.dictutils import subdict
|
||||
from pyhocon import ConfigFactory
|
||||
|
||||
from config import config
|
||||
from service_repo.base import PartialVersion
|
||||
from utilities.partial_version import PartialVersion
|
||||
|
||||
HERE = Path(__file__)
|
||||
|
||||
|
@ -1,237 +1,18 @@
|
||||
import atexit
|
||||
from argparse import ArgumentParser
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from flask import Flask
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config, info
|
||||
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
from mongo.initialize import (
|
||||
init_mongo_data,
|
||||
pre_populate_data,
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
from sync import distributed_lock
|
||||
from timing_context import TimingContext
|
||||
from updates import check_updates_thread
|
||||
from utilities import json
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from app_routes import register_routes
|
||||
from app_sequence import AppSequence
|
||||
from config import config
|
||||
|
||||
app = Flask(__name__, static_url_path="/static")
|
||||
CORS(app, **config.get("apiserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
log.info("################ API Server initializing #####################")
|
||||
|
||||
app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
|
||||
app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
|
||||
|
||||
database.initialize()
|
||||
|
||||
# build a key that uniquely identifies specific mongo instance
|
||||
hosts_string = ";".join(sorted(database.get_hosts()))
|
||||
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
|
||||
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
|
||||
upgrade_monitoring = config.get(
|
||||
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
|
||||
)
|
||||
try:
|
||||
empty_es = check_elastic_empty()
|
||||
except ElasticConnectionError as err:
|
||||
if not upgrade_monitoring:
|
||||
raise
|
||||
log.error(err)
|
||||
info.es_connection_error = True
|
||||
|
||||
empty_db = check_mongo_empty()
|
||||
if (
|
||||
upgrade_monitoring
|
||||
and not empty_db
|
||||
and (info.es_connection_error or empty_es)
|
||||
and get_last_server_version() < Version("0.16.0")
|
||||
):
|
||||
log.info(f"ES database seems not migrated")
|
||||
info.missed_es_upgrade = True
|
||||
|
||||
if info.es_connection_error and not info.missed_es_upgrade:
|
||||
raise Exception(
|
||||
"Error starting server: failed connecting to ElasticSearch service"
|
||||
)
|
||||
|
||||
if not info.missed_es_upgrade:
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
|
||||
if (
|
||||
not info.missed_es_upgrade
|
||||
and empty_db
|
||||
and config.get("apiserver.pre_populate.enabled", False)
|
||||
):
|
||||
pre_populate_data()
|
||||
|
||||
|
||||
ServiceRepo.load("services")
|
||||
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
|
||||
check_updates_thread.start()
|
||||
StatisticsReporter.start()
|
||||
|
||||
|
||||
def graceful_shutdown():
|
||||
ThreadsManager.terminating = True
|
||||
|
||||
|
||||
atexit.register(graceful_shutdown)
|
||||
|
||||
|
||||
@app.before_first_request
|
||||
def before_app_first_request():
|
||||
pass
|
||||
|
||||
|
||||
@app.before_request
|
||||
def before_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
|
||||
try:
|
||||
call = create_api_call(request)
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
headers = {}
|
||||
if call.result.filename:
|
||||
headers[
|
||||
"Content-Disposition"
|
||||
] = f"attachment; filename={call.result.filename}"
|
||||
|
||||
if call.result.headers:
|
||||
headers.update(call.result.headers)
|
||||
|
||||
response = Response(
|
||||
content, mimetype=content_type, status=call.result.code, headers=headers
|
||||
)
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
if value is None:
|
||||
response.set_cookie(key, "", expires=0)
|
||||
else:
|
||||
response.set_cookie(
|
||||
key, value, **config.get("apiserver.auth.cookies")
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
|
||||
def update_call_data(call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
items = []
|
||||
for i, line in enumerate(req.data.splitlines()):
|
||||
try:
|
||||
event = json.loads(line)
|
||||
if not isinstance(event, dict):
|
||||
raise BadRequest(
|
||||
f"json lines must contain objects, found: {type(event).__name__}"
|
||||
)
|
||||
items.append(event)
|
||||
except ValueError as e:
|
||||
msg = f"{e} in batch item #{i}"
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
|
||||
|
||||
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
|
||||
call = call or APICall(
|
||||
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
|
||||
)
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
return call
|
||||
|
||||
|
||||
def create_api_call(req):
|
||||
call = None
|
||||
try:
|
||||
# Parse the request path
|
||||
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
auth_cookie = req.cookies.get(
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
)
|
||||
headers = (
|
||||
{}
|
||||
if not auth_cookie
|
||||
else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
|
||||
)
|
||||
headers.update(
|
||||
list(req.headers.items())
|
||||
) # add (possibly override with) the headers
|
||||
|
||||
# Construct call instance
|
||||
call = APICall(
|
||||
endpoint_name=endpoint_name,
|
||||
remote_addr=req.remote_addr,
|
||||
endpoint_version=endpoint_version,
|
||||
headers=headers,
|
||||
files=req.files,
|
||||
)
|
||||
|
||||
# Update call data from request
|
||||
with TimingContext("preprocess", "update_call_data"):
|
||||
update_call_data(call, req)
|
||||
|
||||
except PathParsingError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.args[0], 400)
|
||||
call.log_api = False
|
||||
except BadRequest as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.description, 400)
|
||||
except BaseError as ex:
|
||||
call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
|
||||
except Exception as ex:
|
||||
log.exception("Error creating call")
|
||||
call = _call_or_empty_with_error(
|
||||
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
|
||||
)
|
||||
|
||||
return call
|
||||
register_routes(app)
|
||||
AppSequence(app).start()
|
||||
|
||||
|
||||
# =================== MAIN =======================
|
||||
|
@ -9,9 +9,9 @@ from six import string_types
|
||||
import database
|
||||
from timing_context import TimingContext, TimingStats
|
||||
from utilities import json
|
||||
from utilities.partial_version import PartialVersion
|
||||
from .auth import Identity
|
||||
from .auth import Payload as AuthPayload
|
||||
from .base import PartialVersion
|
||||
from .errors import CallParsingError
|
||||
from .schema_validator import SchemaValidator
|
||||
|
||||
@ -305,6 +305,7 @@ class APICall(DataContainer):
|
||||
headers=None,
|
||||
files=None,
|
||||
trx=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(APICall, self).__init__(data=data, batched_data=batched_data)
|
||||
|
||||
|
@ -5,8 +5,8 @@ from jsonmodels import models
|
||||
from jsonmodels.errors import FieldNotSupported
|
||||
|
||||
from schema import schema
|
||||
from utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
from .base import PartialVersion
|
||||
from .schema_validator import SchemaValidator
|
||||
|
||||
EndpointFunc = Callable[[APICall, Text, models.Base], None]
|
||||
|
@ -11,7 +11,7 @@ from apierrors import APIError
|
||||
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
|
||||
from api_version import __version__ as _api_version_
|
||||
from config import config
|
||||
from service_repo.base import PartialVersion
|
||||
from utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
from .endpoint import Endpoint
|
||||
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
|
||||
|
@ -69,9 +69,9 @@ from database.model.task.task import (
|
||||
)
|
||||
from database.utils import get_fields, parse_from_call
|
||||
from service_repo import APICall, endpoint
|
||||
from service_repo.base import PartialVersion
|
||||
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities.partial_version import PartialVersion
|
||||
|
||||
task_fields = set(Task.get_fields())
|
||||
task_script_fields = set(get_fields(Script))
|
||||
|
@ -5,7 +5,7 @@ from apimodels.organization import Filter
|
||||
from database.model.base import GetMixin
|
||||
from database.utils import partition_tags
|
||||
from service_repo import APICall
|
||||
from service_repo.base import PartialVersion
|
||||
from utilities.partial_version import PartialVersion
|
||||
|
||||
|
||||
def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
|
7
apiserver/utilities/partial_version.py
Normal file
7
apiserver/utilities/partial_version.py
Normal file
@ -0,0 +1,7 @@
|
||||
from semantic_version import Version
|
||||
|
||||
|
||||
class PartialVersion(Version):
|
||||
def __init__(self, version_string: str):
|
||||
assert isinstance(version_string, str)
|
||||
super().__init__(version_string, partial=True)
|
Loading…
Reference in New Issue
Block a user