Add async callback support to router

clearml 2024-12-23 00:13:44 +02:00
parent aed1b46612
commit be9965a6a5
4 changed files with 89 additions and 30 deletions

View File

@@ -14,13 +14,15 @@ from ..utilities.process.mp import SafeQueue
 class FastAPIProxy:
     ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]

-    def __init__(self, port, workers=None, default_target=None):
+    def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None):
         self.app = None
         self.routes = {}
         self.port = port
         self.message_queue = SafeQueue()
         self.uvicorn_subprocess = None
         self.workers = workers
+        self.access_log = access_log
+        self.log_level = log_level
         self._default_target = default_target
         self._default_session = None
         self._in_subprocess = False
@@ -58,6 +60,7 @@ class FastAPIProxy:
                         headers=dict(proxied_response.headers),
                         status_code=proxied_response.status_code,
                     )

         self.app.add_middleware(DefaultRouteMiddleware)

     async def proxy(
@@ -70,20 +73,24 @@ class FastAPIProxy:
         if not route_data:
             return Response(status_code=404)
-        request = route_data.on_request(request)
-        proxied_response = await route_data.session.request(
-            method=request.method,
-            url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
-            headers=dict(request.headers),
-            content=await request.body(),
-            params=request.query_params,
-        )
-        proxied_response = Response(
-            content=proxied_response.content,
-            headers=dict(proxied_response.headers),
-            status_code=proxied_response.status_code,
-        )
-        return route_data.on_response(proxied_response, request)
+        request = await route_data.on_request(request)
+        try:
+            proxied_response = await route_data.session.request(
+                method=request.method,
+                url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
+                headers=dict(request.headers),
+                content=await request.body(),
+                params=request.query_params,
+            )
+            proxied_response = Response(
+                content=proxied_response.content,
+                headers=dict(proxied_response.headers),
+                status_code=proxied_response.status_code,
+            )
+        except Exception as e:
+            await route_data.on_error(request, e)
+            raise
+        return await route_data.on_response(proxied_response, request)

     def add_route(
         self,
@@ -91,7 +98,8 @@
         target,
         request_callback=None,
         response_callback=None,
-        endpoint_telemetry=True
+        error_callback=None,
+        endpoint_telemetry=True,
     ):
         if not self._in_subprocess:
             self.message_queue.put(
@@ -102,7 +110,8 @@
                         "target": target,
                         "request_callback": request_callback,
                         "response_callback": response_callback,
-                        "endpoint_telemetry": endpoint_telemetry
+                        "error_callback": error_callback,
+                        "endpoint_telemetry": endpoint_telemetry,
                     },
                 }
             )
@@ -116,6 +125,7 @@
             target,
             request_callback=request_callback,
             response_callback=response_callback,
+            error_callback=error_callback,
             session=httpx.AsyncClient(timeout=None),
         )
         if endpoint_telemetry is True:
@@ -164,7 +174,14 @@
         for route in self.routes.values():
             route.start_endpoint_telemetry()
         threading.Thread(target=self._rpc_manager, daemon=True).start()
-        uvicorn.run(self.app, port=self.port, host="0.0.0.0", workers=self.workers)
+        uvicorn.run(
+            self.app,
+            port=self.port,
+            host="0.0.0.0",
+            workers=self.workers,
+            log_level=self.log_level,
+            access_log=self.access_log,
+        )

     def _rpc_manager(self):
         while True:
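Note on the new error path: the proxied request is now wrapped in a try/except, the route's on_error hook is awaited with the original exception, and the exception is then re-raised. An error callback can therefore observe or log a failure, but it cannot substitute a response. A minimal sketch of such a callback (function name and logging are illustrative, not part of this commit):

async def log_upstream_error(request, error, persistent_state):
    # Called via Route.on_error(); the proxy re-raises the exception afterwards,
    # so this hook is for logging/metrics rather than recovery.
    persistent_state["errors"] = persistent_state.get("errors", 0) + 1
    print(f"proxying {request.url.path} failed: {error!r}")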

View File

@@ -4,20 +4,35 @@ from .fastapi_proxy import FastAPIProxy
 class HttpProxy:
     DEFAULT_PORT = 9000

-    def __init__(self, port=None, workers=None, default_target=None):
+    def __init__(self, port=None, workers=None, default_target=None, log_level=None, access_log=True):
         # at the moment, only a fastapi proxy is supported
-        self.base_proxy = FastAPIProxy(port or self.DEFAULT_PORT, workers=workers, default_target=default_target)
+        self.base_proxy = FastAPIProxy(
+            port or self.DEFAULT_PORT,
+            workers=workers,
+            default_target=default_target,
+            log_level=log_level,
+            access_log=access_log,
+        )
         self.base_proxy.start()
         self.port = port
         self.routes = {}

-    def add_route(self, source, target, request_callback=None, response_callback=None, endpoint_telemetry=True):
+    def add_route(
+        self,
+        source,
+        target,
+        request_callback=None,
+        response_callback=None,
+        endpoint_telemetry=True,
+        error_callback=None,
+    ):
         self.routes[source] = self.base_proxy.add_route(
             source=source,
             target=target,
             request_callback=request_callback,
             response_callback=response_callback,
             endpoint_telemetry=endpoint_telemetry,
+            error_callback=error_callback,
         )
         return self.routes[source]
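For reference, a sketch of driving the extended HttpProxy surface directly (the import path is an assumption based on the relative imports above, and the callback body is illustrative; in normal use the proxy is managed through HttpRouter):

from clearml.router.http_proxy import HttpProxy  # assumed module path

def on_error(request, error, persistent_state):
    # Sync callbacks still work; async ones are awaited by the Route hooks.
    print(f"upstream call failed: {error!r}")

proxy = HttpProxy(
    port=9000,
    default_target="http://localhost:8080",
    log_level="warning",   # forwarded to uvicorn.run()
    access_log=False,
)
proxy.add_route(
    source="/api",
    target="http://localhost:8080/api",
    error_callback=on_error,  # new in this commit
    endpoint_telemetry=False,
)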

View File

@@ -1,11 +1,13 @@
+import inspect
 from .endpoint_telemetry import EndpointTelemetry


 class Route:
-    def __init__(self, target_url, request_callback=None, response_callback=None, session=None):
+    def __init__(self, target_url, request_callback=None, response_callback=None, session=None, error_callback=None):
         self.target_url = target_url
         self.request_callback = request_callback
         self.response_callback = response_callback
+        self.error_callback = error_callback
         self.session = session
         self.persistent_state = {}
         self._endpoint_telemetry = None
@@ -62,18 +64,30 @@
         self._endpoint_telemetry.stop()
         self._endpoint_telemetry = None

-    def on_request(self, request):
+    async def on_request(self, request):
         new_request = request
         if self.request_callback:
             new_request = self.request_callback(request, persistent_state=self.persistent_state) or request
+            if inspect.isawaitable(new_request):
+                new_request = (await new_request) or request
         if self._endpoint_telemetry:
             self._endpoint_telemetry.on_request()
         return new_request

-    def on_response(self, response, request):
+    async def on_response(self, response, request):
         new_response = response
         if self.response_callback:
             new_response = self.response_callback(response, request, persistent_state=self.persistent_state) or response
+            if inspect.isawaitable(new_response):
+                new_response = (await new_response) or response
         if self._endpoint_telemetry:
             self._endpoint_telemetry.on_response()
         return new_response
+
+    async def on_error(self, request, error):
+        on_error_result = None
+        if self.error_callback:
+            on_error_result = self.error_callback(request, error, persistent_state=self.persistent_state)
+            if inspect.isawaitable(on_error_result):
+                await on_error_result
+        return on_error_result
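Route now accepts either plain functions or coroutine functions for all three callbacks: it calls the callback and awaits the result only when inspect.isawaitable() reports an awaitable. A self-contained illustration of that dispatch pattern (helper and callback names are illustrative, not part of the commit):

import asyncio
import inspect

async def call_maybe_async(callback, *args, **kwargs):
    # Same pattern as Route.on_request/on_response/on_error:
    # invoke the callback, await the result only if it is awaitable.
    result = callback(*args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result

def sync_cb(value, persistent_state):
    return value.upper()

async def async_cb(value, persistent_state):
    await asyncio.sleep(0)  # stand-in for real async work
    return value[::-1]

async def main():
    state = {}
    print(await call_maybe_async(sync_cb, "hello", persistent_state=state))   # HELLO
    print(await call_maybe_async(async_cb, "hello", persistent_state=state))  # olleh

asyncio.run(main())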

View File

@@ -43,19 +43,24 @@ class HttpRouter:
         self._task = task
         self._external_endpoint_port = None
         self._proxy = None
-        self._proxy_params = {"port": HttpProxy.DEFAULT_PORT}
+        self._proxy_params = {"port": HttpProxy.DEFAULT_PORT, "access_log": True}

-    def set_local_proxy_parameters(self, incoming_port=None, default_target=None):
-        # type: (Optional[int], Optional[str]) -> ()
+    def set_local_proxy_parameters(self, incoming_port=None, default_target=None, log_level=None, access_log=True):
+        # type: (Optional[int], Optional[str], Optional[str], bool) -> ()
         """
         Set the parameters with which the local proxy is initialized

         :param incoming_port: The incoming port of the proxy
         :param default_target: If None, no default target is set. Otherwise, route all traffic
             that doesn't match a local route created via `create_local_route` to this target
+        :param log_level: Python log level for the proxy, one of:
+            'critical', 'error', 'warning', 'info', 'debug', 'trace'
+        :param access_log: Enable/Disable access log
         """
         self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT
         self._proxy_params["default_target"] = default_target
+        self._proxy_params["log_level"] = log_level
+        self._proxy_params["access_log"] = access_log

     def start_local_proxy(self):
         """
@@ -69,7 +74,8 @@
         target,  # type: str
         request_callback=None,  # type: Callable[Request, Dict]
         response_callback=None,  # type: Callable[Response, Request, Dict]
-        endpoint_telemetry=True  # type: Union[bool, Dict]
+        endpoint_telemetry=True,  # type: Union[bool, Dict]
+        error_callback=None  # type: Callable[Request, Exception, Dict]
     ):
         """
         Create a local route from a source to a target through a proxy. If no proxy instance
@@ -88,14 +94,14 @@
             The callback must have the following parameters:
               - request - The intercepted FastAPI request
               - persistent_state - A dictionary meant to be used as a caching utility object.
-                Shared with `response_callback`
+                Shared with `response_callback` and `error_callback`
             The callback can return a FastAPI Request, in which case this request will be forwarded to the target
         :param response_callback: A function used to process each response before it is returned by the proxy.
             The callback must have the following parameters:
               - response - The FastAPI response
               - request - The FastAPI request (after being preprocessed by the proxy)
               - persistent_state - A dictionary meant to be used as a caching utility object.
-                Shared with `request_callback`
+                Shared with `request_callback` and `error_callback`
             The callback can return a FastAPI Response, in which case this response will be forwarded
         :param endpoint_telemetry: If True, enable endpoint telemetry. If False, disable it.
             If a dictionary is passed, enable endpoint telemetry with custom parameters.
@@ -115,6 +121,12 @@
               - input_size - input size of the model
               - input_type - input type expected by the model/endpoint
               - report_statistics - whether or not to report statistics
+        :param error_callback: Callback to be called on request error.
+            The callback must have the following parameters:
+              - request - the FastAPI request which caused the error
+              - error - an exception which indicates which error occurred
+              - persistent_state - A dictionary meant to be used as a caching utility object.
+                Shared with `request_callback` and `response_callback`
         """
         self.start_local_proxy()
         self._proxy.add_route(
@@ -123,6 +135,7 @@
             request_callback=request_callback,
             response_callback=response_callback,
             endpoint_telemetry=endpoint_telemetry,
+            error_callback=error_callback
         )

     def remove_local_route(self, source):
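Putting the pieces together, a sketch of registering a route with the new async callbacks through create_local_route (how the HttpRouter instance is obtained, and the source/target values, are assumptions; callback bodies are illustrative):

async def on_request(request, persistent_state):
    # May be async now; returning a falsy value keeps the original request.
    persistent_state["last_path"] = request.url.path
    return request

def on_response(response, request, persistent_state):
    # Sync callbacks remain supported.
    response.headers["x-proxied"] = "1"
    return response

async def on_error(request, error, persistent_state):
    # New: invoked when the proxied request raises, before the error propagates.
    print(f"{request.url.path} failed: {error!r}")

router = task.get_http_router()  # assumed accessor for an HttpRouter instance
router.set_local_proxy_parameters(incoming_port=9000, log_level="info", access_log=False)
router.create_local_route(
    source="/model",               # assumed parameter names, mirroring HttpProxy.add_route
    target="http://localhost:8000",
    request_callback=on_request,
    response_callback=on_response,
    error_callback=on_error,
)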