mirror of
https://github.com/clearml/clearml
synced 2025-05-15 18:05:40 +00:00
Add async callback support to router
This commit is contained in:
parent
aed1b46612
commit
be9965a6a5
@ -14,13 +14,15 @@ from ..utilities.process.mp import SafeQueue
|
|||||||
class FastAPIProxy:
|
class FastAPIProxy:
|
||||||
ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
|
ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
|
||||||
|
|
||||||
def __init__(self, port, workers=None, default_target=None):
|
def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None):
|
||||||
self.app = None
|
self.app = None
|
||||||
self.routes = {}
|
self.routes = {}
|
||||||
self.port = port
|
self.port = port
|
||||||
self.message_queue = SafeQueue()
|
self.message_queue = SafeQueue()
|
||||||
self.uvicorn_subprocess = None
|
self.uvicorn_subprocess = None
|
||||||
self.workers = workers
|
self.workers = workers
|
||||||
|
self.access_log = access_log
|
||||||
|
self.log_level = None
|
||||||
self._default_target = default_target
|
self._default_target = default_target
|
||||||
self._default_session = None
|
self._default_session = None
|
||||||
self._in_subprocess = False
|
self._in_subprocess = False
|
||||||
@ -58,6 +60,7 @@ class FastAPIProxy:
|
|||||||
headers=dict(proxied_response.headers),
|
headers=dict(proxied_response.headers),
|
||||||
status_code=proxied_response.status_code,
|
status_code=proxied_response.status_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.add_middleware(DefaultRouteMiddleware)
|
self.app.add_middleware(DefaultRouteMiddleware)
|
||||||
|
|
||||||
async def proxy(
|
async def proxy(
|
||||||
@ -70,20 +73,24 @@ class FastAPIProxy:
|
|||||||
if not route_data:
|
if not route_data:
|
||||||
return Response(status_code=404)
|
return Response(status_code=404)
|
||||||
|
|
||||||
request = route_data.on_request(request)
|
request = await route_data.on_request(request)
|
||||||
proxied_response = await route_data.session.request(
|
try:
|
||||||
method=request.method,
|
proxied_response = await route_data.session.request(
|
||||||
url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
|
method=request.method,
|
||||||
headers=dict(request.headers),
|
url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
|
||||||
content=await request.body(),
|
headers=dict(request.headers),
|
||||||
params=request.query_params,
|
content=await request.body(),
|
||||||
)
|
params=request.query_params,
|
||||||
proxied_response = Response(
|
)
|
||||||
content=proxied_response.content,
|
proxied_response = Response(
|
||||||
headers=dict(proxied_response.headers),
|
content=proxied_response.content,
|
||||||
status_code=proxied_response.status_code,
|
headers=dict(proxied_response.headers),
|
||||||
)
|
status_code=proxied_response.status_code,
|
||||||
return route_data.on_response(proxied_response, request)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await route_data.on_error(request, e)
|
||||||
|
raise
|
||||||
|
return await route_data.on_response(proxied_response, request)
|
||||||
|
|
||||||
def add_route(
|
def add_route(
|
||||||
self,
|
self,
|
||||||
@ -91,7 +98,8 @@ class FastAPIProxy:
|
|||||||
target,
|
target,
|
||||||
request_callback=None,
|
request_callback=None,
|
||||||
response_callback=None,
|
response_callback=None,
|
||||||
endpoint_telemetry=True
|
error_callback=None,
|
||||||
|
endpoint_telemetry=True,
|
||||||
):
|
):
|
||||||
if not self._in_subprocess:
|
if not self._in_subprocess:
|
||||||
self.message_queue.put(
|
self.message_queue.put(
|
||||||
@ -102,7 +110,8 @@ class FastAPIProxy:
|
|||||||
"target": target,
|
"target": target,
|
||||||
"request_callback": request_callback,
|
"request_callback": request_callback,
|
||||||
"response_callback": response_callback,
|
"response_callback": response_callback,
|
||||||
"endpoint_telemetry": endpoint_telemetry
|
"error_callback": error_callback,
|
||||||
|
"endpoint_telemetry": endpoint_telemetry,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -116,6 +125,7 @@ class FastAPIProxy:
|
|||||||
target,
|
target,
|
||||||
request_callback=request_callback,
|
request_callback=request_callback,
|
||||||
response_callback=response_callback,
|
response_callback=response_callback,
|
||||||
|
error_callback=error_callback,
|
||||||
session=httpx.AsyncClient(timeout=None),
|
session=httpx.AsyncClient(timeout=None),
|
||||||
)
|
)
|
||||||
if endpoint_telemetry is True:
|
if endpoint_telemetry is True:
|
||||||
@ -164,7 +174,14 @@ class FastAPIProxy:
|
|||||||
for route in self.routes.values():
|
for route in self.routes.values():
|
||||||
route.start_endpoint_telemetry()
|
route.start_endpoint_telemetry()
|
||||||
threading.Thread(target=self._rpc_manager, daemon=True).start()
|
threading.Thread(target=self._rpc_manager, daemon=True).start()
|
||||||
uvicorn.run(self.app, port=self.port, host="0.0.0.0", workers=self.workers)
|
uvicorn.run(
|
||||||
|
self.app,
|
||||||
|
port=self.port,
|
||||||
|
host="0.0.0.0",
|
||||||
|
workers=self.workers,
|
||||||
|
log_level=self.log_level,
|
||||||
|
access_log=self.access_log,
|
||||||
|
)
|
||||||
|
|
||||||
def _rpc_manager(self):
|
def _rpc_manager(self):
|
||||||
while True:
|
while True:
|
||||||
|
@ -4,20 +4,35 @@ from .fastapi_proxy import FastAPIProxy
|
|||||||
class HttpProxy:
|
class HttpProxy:
|
||||||
DEFAULT_PORT = 9000
|
DEFAULT_PORT = 9000
|
||||||
|
|
||||||
def __init__(self, port=None, workers=None, default_target=None):
|
def __init__(self, port=None, workers=None, default_target=None, log_level=None, access_log=True):
|
||||||
# at the moment, only a fastapi proxy is supported
|
# at the moment, only a fastapi proxy is supported
|
||||||
self.base_proxy = FastAPIProxy(port or self.DEFAULT_PORT, workers=workers, default_target=default_target)
|
self.base_proxy = FastAPIProxy(
|
||||||
|
port or self.DEFAULT_PORT,
|
||||||
|
workers=workers,
|
||||||
|
default_target=default_target,
|
||||||
|
log_level=log_level,
|
||||||
|
access_log=access_log,
|
||||||
|
)
|
||||||
self.base_proxy.start()
|
self.base_proxy.start()
|
||||||
self.port = port
|
self.port = port
|
||||||
self.routes = {}
|
self.routes = {}
|
||||||
|
|
||||||
def add_route(self, source, target, request_callback=None, response_callback=None, endpoint_telemetry=True):
|
def add_route(
|
||||||
|
self,
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
request_callback=None,
|
||||||
|
response_callback=None,
|
||||||
|
endpoint_telemetry=True,
|
||||||
|
error_callback=None,
|
||||||
|
):
|
||||||
self.routes[source] = self.base_proxy.add_route(
|
self.routes[source] = self.base_proxy.add_route(
|
||||||
source=source,
|
source=source,
|
||||||
target=target,
|
target=target,
|
||||||
request_callback=request_callback,
|
request_callback=request_callback,
|
||||||
response_callback=response_callback,
|
response_callback=response_callback,
|
||||||
endpoint_telemetry=endpoint_telemetry,
|
endpoint_telemetry=endpoint_telemetry,
|
||||||
|
error_callback=error_callback,
|
||||||
)
|
)
|
||||||
return self.routes[source]
|
return self.routes[source]
|
||||||
|
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
|
import inspect
|
||||||
from .endpoint_telemetry import EndpointTelemetry
|
from .endpoint_telemetry import EndpointTelemetry
|
||||||
|
|
||||||
|
|
||||||
class Route:
|
class Route:
|
||||||
def __init__(self, target_url, request_callback=None, response_callback=None, session=None):
|
def __init__(self, target_url, request_callback=None, response_callback=None, session=None, error_callback=None):
|
||||||
self.target_url = target_url
|
self.target_url = target_url
|
||||||
self.request_callback = request_callback
|
self.request_callback = request_callback
|
||||||
self.response_callback = response_callback
|
self.response_callback = response_callback
|
||||||
|
self.error_callback = error_callback
|
||||||
self.session = session
|
self.session = session
|
||||||
self.persistent_state = {}
|
self.persistent_state = {}
|
||||||
self._endpoint_telemetry = None
|
self._endpoint_telemetry = None
|
||||||
@ -62,18 +64,30 @@ class Route:
|
|||||||
self._endpoint_telemetry.stop()
|
self._endpoint_telemetry.stop()
|
||||||
self._endpoint_telemetry = None
|
self._endpoint_telemetry = None
|
||||||
|
|
||||||
def on_request(self, request):
|
async def on_request(self, request):
|
||||||
new_request = request
|
new_request = request
|
||||||
if self.request_callback:
|
if self.request_callback:
|
||||||
new_request = self.request_callback(request, persistent_state=self.persistent_state) or request
|
new_request = self.request_callback(request, persistent_state=self.persistent_state) or request
|
||||||
|
if inspect.isawaitable(new_request):
|
||||||
|
new_request = (await new_request) or request
|
||||||
if self._endpoint_telemetry:
|
if self._endpoint_telemetry:
|
||||||
self._endpoint_telemetry.on_request()
|
self._endpoint_telemetry.on_request()
|
||||||
return new_request
|
return new_request
|
||||||
|
|
||||||
def on_response(self, response, request):
|
async def on_response(self, response, request):
|
||||||
new_response = response
|
new_response = response
|
||||||
if self.response_callback:
|
if self.response_callback:
|
||||||
new_response = self.response_callback(response, request, persistent_state=self.persistent_state) or response
|
new_response = self.response_callback(response, request, persistent_state=self.persistent_state) or response
|
||||||
|
if inspect.isawaitable(new_response):
|
||||||
|
new_response = (await new_response) or response
|
||||||
if self._endpoint_telemetry:
|
if self._endpoint_telemetry:
|
||||||
self._endpoint_telemetry.on_response()
|
self._endpoint_telemetry.on_response()
|
||||||
return new_response
|
return new_response
|
||||||
|
|
||||||
|
async def on_error(self, request, error):
|
||||||
|
on_error_result = None
|
||||||
|
if self.error_callback:
|
||||||
|
on_error_result = self.error_callback(request, error, persistent_state=self.persistent_state)
|
||||||
|
if inspect.isawaitable(on_error_result):
|
||||||
|
await on_error_result
|
||||||
|
return on_error_result
|
||||||
|
@ -43,19 +43,24 @@ class HttpRouter:
|
|||||||
self._task = task
|
self._task = task
|
||||||
self._external_endpoint_port = None
|
self._external_endpoint_port = None
|
||||||
self._proxy = None
|
self._proxy = None
|
||||||
self._proxy_params = {"port": HttpProxy.DEFAULT_PORT}
|
self._proxy_params = {"port": HttpProxy.DEFAULT_PORT, "access_log": True}
|
||||||
|
|
||||||
def set_local_proxy_parameters(self, incoming_port=None, default_target=None):
|
def set_local_proxy_parameters(self, incoming_port=None, default_target=None, log_level=None, access_log=True):
|
||||||
# type: (Optional[int], Optional[str]) -> ()
|
# type: (Optional[int], Optional[str], Optional[str], bool) -> ()
|
||||||
"""
|
"""
|
||||||
Set the parameters with which the local proxy is initialized
|
Set the parameters with which the local proxy is initialized
|
||||||
|
|
||||||
:param incoming_port: The incoming port of the proxy
|
:param incoming_port: The incoming port of the proxy
|
||||||
:param default_target: If None, no default target is set. Otherwise, route all traffic
|
:param default_target: If None, no default target is set. Otherwise, route all traffic
|
||||||
that doesn't match a local route created via `create_local_route` to this target
|
that doesn't match a local route created via `create_local_route` to this target
|
||||||
|
:param log_level: Python log level for the proxy, one of:
|
||||||
|
'critical', 'error', 'warning', 'info', 'debug', 'trace'
|
||||||
|
:param access_log: Enable/Disable access log
|
||||||
"""
|
"""
|
||||||
self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT
|
self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT
|
||||||
self._proxy_params["default_target"] = default_target
|
self._proxy_params["default_target"] = default_target
|
||||||
|
self._proxy_params["log_level"] = log_level
|
||||||
|
self._proxy_params["access_log"] = access_log
|
||||||
|
|
||||||
def start_local_proxy(self):
|
def start_local_proxy(self):
|
||||||
"""
|
"""
|
||||||
@ -69,7 +74,8 @@ class HttpRouter:
|
|||||||
target, # type: str
|
target, # type: str
|
||||||
request_callback=None, # type: Callable[Request, Dict]
|
request_callback=None, # type: Callable[Request, Dict]
|
||||||
response_callback=None, # type: Callable[Response, Request, Dict]
|
response_callback=None, # type: Callable[Response, Request, Dict]
|
||||||
endpoint_telemetry=True # type: Union[bool, Dict]
|
endpoint_telemetry=True, # type: Union[bool, Dict]
|
||||||
|
error_callback=None # type: Callable[Request, Exception, Dict]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a local route from a source to a target through a proxy. If no proxy instance
|
Create a local route from a source to a target through a proxy. If no proxy instance
|
||||||
@ -88,14 +94,14 @@ class HttpRouter:
|
|||||||
The callback must have the following parameters:
|
The callback must have the following parameters:
|
||||||
- request - The intercepted FastAPI request
|
- request - The intercepted FastAPI request
|
||||||
- persistent_state - A dictionary meant to be used as a caching utility object.
|
- persistent_state - A dictionary meant to be used as a caching utility object.
|
||||||
Shared with `response_callback`
|
Shared with `response_callback` and `error_callback`
|
||||||
The callback can return a FastAPI Request, in which case this request will be forwarded to the target
|
The callback can return a FastAPI Request, in which case this request will be forwarded to the target
|
||||||
:param response_callback: A function used to process each response before it is returned by the proxy.
|
:param response_callback: A function used to process each response before it is returned by the proxy.
|
||||||
The callback must have the following parameters:
|
The callback must have the following parameters:
|
||||||
- response - The FastAPI response
|
- response - The FastAPI response
|
||||||
- request - The FastAPI request (after being preprocessed by the proxy)
|
- request - The FastAPI request (after being preprocessed by the proxy)
|
||||||
- persistent_state - A dictionary meant to be used as a caching utility object.
|
- persistent_state - A dictionary meant to be used as a caching utility object.
|
||||||
Shared with `request_callback`
|
Shared with `request_callback` and `error_callback`
|
||||||
The callback can return a FastAPI Response, in which case this response will be forwarded
|
The callback can return a FastAPI Response, in which case this response will be forwarded
|
||||||
:param endpoint_telemetry: If True, enable endpoint telemetry. If False, disable it.
|
:param endpoint_telemetry: If True, enable endpoint telemetry. If False, disable it.
|
||||||
If a dictionary is passed, enable endpoint telemetry with custom parameters.
|
If a dictionary is passed, enable endpoint telemetry with custom parameters.
|
||||||
@ -115,6 +121,12 @@ class HttpRouter:
|
|||||||
- input_size - input size of the model
|
- input_size - input size of the model
|
||||||
- input_type - input type expected by the model/endpoint
|
- input_type - input type expected by the model/endpoint
|
||||||
- report_statistics - whether or not to report statistics
|
- report_statistics - whether or not to report statistics
|
||||||
|
:param error_callback: Callback to be called on request error.
|
||||||
|
The callback must have the following parameters:
|
||||||
|
- request - the FastAPI request which caused the error
|
||||||
|
- error - an exception which indicates which error occurred
|
||||||
|
- persistent_state - A dictionary meant to be used as a caching utility object.
|
||||||
|
Shared with `request_callback` and `response_callback`
|
||||||
"""
|
"""
|
||||||
self.start_local_proxy()
|
self.start_local_proxy()
|
||||||
self._proxy.add_route(
|
self._proxy.add_route(
|
||||||
@ -123,6 +135,7 @@ class HttpRouter:
|
|||||||
request_callback=request_callback,
|
request_callback=request_callback,
|
||||||
response_callback=response_callback,
|
response_callback=response_callback,
|
||||||
endpoint_telemetry=endpoint_telemetry,
|
endpoint_telemetry=endpoint_telemetry,
|
||||||
|
error_callback=error_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_local_route(self, source):
|
def remove_local_route(self, source):
|
||||||
|
Loading…
Reference in New Issue
Block a user