Add support for streaming in router

This commit is contained in:
clearml 2025-01-19 16:09:53 +02:00
parent 66252b0c36
commit 844d193a4b
3 changed files with 79 additions and 30 deletions

View File

@ -6,6 +6,7 @@ from typing import Optional
import httpx import httpx
import uvicorn import uvicorn
from fastapi import FastAPI, Request, Response from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Match from starlette.routing import Match
@ -16,7 +17,7 @@ from ..utilities.process.mp import SafeQueue
class FastAPIProxy: class FastAPIProxy:
ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None): def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None, enable_streaming=True):
self.app = None self.app = None
self.routes = {} self.routes = {}
self.port = port self.port = port
@ -25,6 +26,7 @@ class FastAPIProxy:
self.workers = workers self.workers = workers
self.access_log = access_log self.access_log = access_log
self.log_level = None self.log_level = None
self.enable_streaming = enable_streaming
self._default_target = default_target self._default_target = default_target
self._default_session = None self._default_session = None
self._in_subprocess = False self._in_subprocess = False
@ -50,18 +52,10 @@ class FastAPIProxy:
for route in proxy.app.router.routes: for route in proxy.app.router.routes:
if route.matches(scope)[0] == Match.FULL: if route.matches(scope)[0] == Match.FULL:
return await call_next(request) return await call_next(request)
proxied_response = await proxy._default_session.request( proxied_response = await proxy._send_request(
method=request.method, request, proxy._default_target, proxy._default_target + request.url.path
url=proxy._default_target + request.url.path,
headers=dict(request.headers),
content=await request.body(),
params=request.query_params,
)
return Response(
content=proxied_response.content,
headers=dict(proxied_response.headers),
status_code=proxied_response.status_code,
) )
return await proxy._convert_httpx_response_to_fastapi(proxied_response)
self.app.add_middleware(DefaultRouteMiddleware) self.app.add_middleware(DefaultRouteMiddleware)
@ -77,23 +71,71 @@ class FastAPIProxy:
request = await route_data.on_request(request) request = await route_data.on_request(request)
try: try:
proxied_response = await route_data.session.request( proxied_response = await self._send_request(
method=request.method, request, route_data.session, url=f"{route_data.target_url}/{path}" if path else route_data.target_url
url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
headers=dict(request.headers),
content=await request.body(),
params=request.query_params,
)
proxied_response = Response(
content=proxied_response.content,
headers=dict(proxied_response.headers),
status_code=proxied_response.status_code,
) )
proxied_response = await self._convert_httpx_response_to_fastapi(proxied_response)
except Exception as e: except Exception as e:
await route_data.on_error(request, e) await route_data.on_error(request, e)
raise raise
return await route_data.on_response(proxied_response, request) return await route_data.on_response(proxied_response, request)
async def _send_request(self, request, session, url):
    """
    Forward an incoming FastAPI request to the target `url` using `session`.

    :param request: Incoming fastapi/starlette Request to proxy
    :param session: httpx async client used to issue the outbound call
    :param url: Fully-qualified target URL for the proxied request

    :return: The httpx response. When `self.enable_streaming` is set, the
        response was sent with `stream=True` and its body has NOT been read
        yet (callers must read or iterate it); otherwise it is fully loaded.
    """
    if self.enable_streaming:
        # Build the outbound request explicitly so the incoming body can be
        # forwarded as a stream instead of being buffered in memory first
        outbound = session.build_request(
            method=request.method,
            url=url,
            content=request.stream(),
            params=request.query_params,
            headers=dict(request.headers),
            timeout=httpx.USE_CLIENT_DEFAULT,
        )
        return await session.send(
            request=outbound,
            auth=httpx.USE_CLIENT_DEFAULT,
            follow_redirects=httpx.USE_CLIENT_DEFAULT,
            stream=True,
        )
    # Non-streaming path: read the whole incoming body, then issue one request
    return await session.request(
        method=request.method,
        url=url,
        headers=dict(request.headers),
        content=await request.body(),
        params=request.query_params,
    )
async def _convert_httpx_response_to_fastapi(self, httpx_response):
    """
    Convert an httpx response into a FastAPI/Starlette response.

    Chunked upstream responses are re-streamed to the client when
    `self.enable_streaming` is set; otherwise the body is returned in full.

    :param httpx_response: Response produced by `_send_request`. When
        streaming is enabled its body may not have been read yet.

    :return: A StreamingResponse for chunked upstream responses (when
        streaming is enabled), otherwise a fully-buffered Response
    """
    if self.enable_streaming and httpx_response.headers.get("transfer-encoding", "").lower() == "chunked":
        async def upstream_body_generator():
            async for chunk in httpx_response.aiter_bytes():
                yield chunk

        # Do not forward the hop-by-hop "transfer-encoding" header: the ASGI
        # server applies its own framing to the streamed body, and echoing the
        # upstream value can produce an invalid/duplicated header (RFC 9112)
        headers = dict(httpx_response.headers)
        headers.pop("transfer-encoding", None)
        return StreamingResponse(
            upstream_body_generator(), status_code=httpx_response.status_code, headers=headers
        )
    if not self.enable_streaming:
        content = httpx_response.content
    else:
        # response was sent with stream=True, so the body must be read explicitly
        content = await httpx_response.aread()
    fastapi_response = Response(
        content=content,
        status_code=httpx_response.status_code,
        media_type=httpx_response.headers.get("content-type", None),
        headers=dict(httpx_response.headers),
    )
    # should delete content-length when not present in the original response
    # relevant for:
    # https://datatracker.ietf.org/doc/html/rfc9112#body.content-length:~:text=MUST%20NOT%20send%20a%20Content%2DLength%20header
    if httpx_response.headers.get("content-length") is None:
        # Starlette's MutableHeaders has no pop(); delete and tolerate absence.
        # Only KeyError can occur here — do not swallow anything broader.
        try:
            del fastapi_response.headers["content-length"]
        except KeyError:
            pass
    return fastapi_response
def add_route( def add_route(
self, self,
source, source,

View File

@ -4,7 +4,9 @@ from .fastapi_proxy import FastAPIProxy
class HttpProxy: class HttpProxy:
DEFAULT_PORT = 9000 DEFAULT_PORT = 9000
def __init__(self, port=None, workers=None, default_target=None, log_level=None, access_log=True): def __init__(
self, port=None, workers=None, default_target=None, log_level=None, access_log=True, enable_streaming=True
):
# at the moment, only a fastapi proxy is supported # at the moment, only a fastapi proxy is supported
self.base_proxy = FastAPIProxy( self.base_proxy = FastAPIProxy(
port or self.DEFAULT_PORT, port or self.DEFAULT_PORT,
@ -12,6 +14,7 @@ class HttpProxy:
default_target=default_target, default_target=default_target,
log_level=log_level, log_level=log_level,
access_log=access_log, access_log=access_log,
enable_streaming=enable_streaming,
) )
self.base_proxy.start() self.base_proxy.start()
self.port = port self.port = port

View File

@ -34,6 +34,7 @@ class HttpRouter:
) )
router.deploy(wait=True) router.deploy(wait=True)
""" """
_instance = None _instance = None
def __init__(self, task): def __init__(self, task):
@ -45,7 +46,9 @@ class HttpRouter:
self._proxy = None self._proxy = None
self._proxy_params = {"port": HttpProxy.DEFAULT_PORT, "access_log": True} self._proxy_params = {"port": HttpProxy.DEFAULT_PORT, "access_log": True}
def set_local_proxy_parameters(self, incoming_port=None, default_target=None, log_level=None, access_log=True): def set_local_proxy_parameters(
self, incoming_port=None, default_target=None, log_level=None, access_log=True, enable_streaming=True
):
# type: (Optional[int], Optional[str], Optional[str], bool) -> () # type: (Optional[int], Optional[str], Optional[str], bool) -> ()
""" """
Set the parameters with which the local proxy is initialized Set the parameters with which the local proxy is initialized
@ -56,11 +59,14 @@ class HttpRouter:
:param log_level: Python log level for the proxy, one of: :param log_level: Python log level for the proxy, one of:
'critical', 'error', 'warning', 'info', 'debug', 'trace' 'critical', 'error', 'warning', 'info', 'debug', 'trace'
:param access_log: Enable/Disable access log :param access_log: Enable/Disable access log
:param enable_streaming: If True, enable streaming of responses with the `transfer-encoding` header set.
If False, no response will be streamed
""" """
self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT
self._proxy_params["default_target"] = default_target self._proxy_params["default_target"] = default_target
self._proxy_params["log_level"] = log_level self._proxy_params["log_level"] = log_level
self._proxy_params["access_log"] = access_log self._proxy_params["access_log"] = access_log
self._proxy_params["enable_streaming"] = enable_streaming
def start_local_proxy(self): def start_local_proxy(self):
""" """
@ -75,7 +81,7 @@ class HttpRouter:
request_callback=None, # type: Callable[Request, Dict] request_callback=None, # type: Callable[Request, Dict]
response_callback=None, # type: Callable[Response, Request, Dict] response_callback=None, # type: Callable[Response, Request, Dict]
endpoint_telemetry=True, # type: Union[bool, Dict] endpoint_telemetry=True, # type: Union[bool, Dict]
error_callback=None # type: Callable[Request, Exception, Dict] error_callback=None, # type: Callable[Request, Exception, Dict]
): ):
""" """
Create a local route from a source to a target through a proxy. If no proxy instance Create a local route from a source to a target through a proxy. If no proxy instance
@ -135,7 +141,7 @@ class HttpRouter:
request_callback=request_callback, request_callback=request_callback,
response_callback=response_callback, response_callback=response_callback,
endpoint_telemetry=endpoint_telemetry, endpoint_telemetry=endpoint_telemetry,
error_callback=error_callback error_callback=error_callback,
) )
def remove_local_route(self, source): def remove_local_route(self, source):
@ -148,9 +154,7 @@ class HttpRouter:
if self._proxy: if self._proxy:
self._proxy.remove_route(source) self._proxy.remove_route(source)
def deploy( def deploy(self, wait=False, wait_interval_seconds=3.0, wait_timeout_seconds=90.0):
self, wait=False, wait_interval_seconds=3.0, wait_timeout_seconds=90.0
):
# type: (Optional[int], str, bool, float, float) -> Optional[Dict] # type: (Optional[int], str, bool, float, float) -> Optional[Dict]
""" """
Start the local HTTP proxy and request an external endpoint for an application Start the local HTTP proxy and request an external endpoint for an application