Add support for streaming in router

clearml 2025-01-19 16:09:53 +02:00
parent 66252b0c36
commit 844d193a4b
3 changed files with 79 additions and 30 deletions
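
A minimal sketch of how the new flag might be used from the router side. The `set_local_proxy_parameters`, `deploy`, and `enable_streaming` pieces are the ones touched in this commit; the import path, the route source/target values, and the `create_local_route` arguments are illustrative assumptions, not part of this diff:

    from clearml import Task
    from clearml.router import HttpRouter  # import path assumed for this sketch

    task = Task.init(project_name="examples", task_name="router")
    router = HttpRouter(task)  # construction per __init__(self, task); the public accessor may differ
    # enable_streaming=True (the default) relays chunked upstream responses as they arrive;
    # enable_streaming=False buffers every response before returning it
    router.set_local_proxy_parameters(incoming_port=9000, enable_streaming=True)
    router.create_local_route(
        source="/serve",                 # assumed source path
        target="http://localhost:8080",  # assumed upstream target
    )
    router.deploy(wait=True)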


@@ -6,6 +6,7 @@ from typing import Optional
import httpx
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Match
@@ -16,7 +17,7 @@ from ..utilities.process.mp import SafeQueue
class FastAPIProxy:
    ALL_REST_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]

    def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None):
    def __init__(self, port, workers=None, default_target=None, log_level=None, access_log=None, enable_streaming=True):
        self.app = None
        self.routes = {}
        self.port = port
@@ -25,6 +26,7 @@ class FastAPIProxy:
        self.workers = workers
        self.access_log = access_log
        self.log_level = None
        self.enable_streaming = enable_streaming
        self._default_target = default_target
        self._default_session = None
        self._in_subprocess = False
@@ -50,18 +52,10 @@
                for route in proxy.app.router.routes:
                    if route.matches(scope)[0] == Match.FULL:
                        return await call_next(request)
                proxied_response = await proxy._default_session.request(
                    method=request.method,
                    url=proxy._default_target + request.url.path,
                    headers=dict(request.headers),
                    content=await request.body(),
                    params=request.query_params,
                )
                return Response(
                    content=proxied_response.content,
                    headers=dict(proxied_response.headers),
                    status_code=proxied_response.status_code,
                proxied_response = await proxy._send_request(
                    request, proxy._default_session, proxy._default_target + request.url.path
                )
                return await proxy._convert_httpx_response_to_fastapi(proxied_response)

        self.app.add_middleware(DefaultRouteMiddleware)
@@ -77,23 +71,71 @@
        request = await route_data.on_request(request)
        try:
            proxied_response = await route_data.session.request(
                method=request.method,
                url=f"{route_data.target_url}/{path}" if path else route_data.target_url,
                headers=dict(request.headers),
                content=await request.body(),
                params=request.query_params,
            )
            proxied_response = Response(
                content=proxied_response.content,
                headers=dict(proxied_response.headers),
                status_code=proxied_response.status_code,
            proxied_response = await self._send_request(
                request, route_data.session, url=f"{route_data.target_url}/{path}" if path else route_data.target_url
            )
            proxied_response = await self._convert_httpx_response_to_fastapi(proxied_response)
        except Exception as e:
            await route_data.on_error(request, e)
            raise
        return await route_data.on_response(proxied_response, request)

    async def _send_request(self, request, session, url):
        if not self.enable_streaming:
            # buffered mode: read the entire request body up front and wait for the full upstream response
            proxied_response = await session.request(
                method=request.method,
                url=url,
                headers=dict(request.headers),
                content=await request.body(),
                params=request.query_params
            )
        else:
            # streaming mode: forward the incoming request body as it arrives and
            # keep the upstream response open so it can be relayed chunk by chunk
            request = session.build_request(
                method=request.method,
                url=url,
                content=request.stream(),
                params=request.query_params,
                headers=dict(request.headers),
                timeout=httpx.USE_CLIENT_DEFAULT
            )
            proxied_response = await session.send(
                request=request,
                auth=httpx.USE_CLIENT_DEFAULT,
                follow_redirects=httpx.USE_CLIENT_DEFAULT,
                stream=True,
            )
        return proxied_response

    async def _convert_httpx_response_to_fastapi(self, httpx_response):
        if self.enable_streaming and httpx_response.headers.get("transfer-encoding", "").lower() == "chunked":
            # chunked upstream responses are relayed as they arrive instead of being buffered
            async def upstream_body_generator():
                async for chunk in httpx_response.aiter_bytes():
                    yield chunk

            return StreamingResponse(
                upstream_body_generator(), status_code=httpx_response.status_code, headers=dict(httpx_response.headers)
            )

        if not self.enable_streaming:
            content = httpx_response.content
        else:
            # the response was opened with stream=True, so the body has to be read explicitly
            content = await httpx_response.aread()
        fastapi_response = Response(
            content=content,
            status_code=httpx_response.status_code,
            media_type=httpx_response.headers.get("content-type", None),
            headers=dict(httpx_response.headers),
        )
        # content-length should be deleted when it was not present in the original response
        # relevant for:
        # https://datatracker.ietf.org/doc/html/rfc9112#body.content-length:~:text=MUST%20NOT%20send%20a%20Content%2DLength%20header
        if httpx_response.headers.get("content-length") is None:
            try:
                del fastapi_response.headers["content-length"]  # no pop available
            except Exception:
                pass
        return fastapi_response

    def add_route(
        self,
        source,

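What the streaming path changes in practice, seen from a client of the proxy: when the upstream replies with Transfer-Encoding: chunked, the body is now relayed as it is produced instead of after it has been fully buffered. A small httpx sketch (the port and endpoint below are assumptions):

    import httpx

    # assumes the proxy listens on localhost:9000 and "/events" is proxied to a chunked upstream endpoint
    with httpx.stream("GET", "http://localhost:9000/events") as response:
        for chunk in response.iter_bytes():
            print(len(chunk))  # chunks arrive incrementally rather than all at once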

@@ -4,7 +4,9 @@ from .fastapi_proxy import FastAPIProxy
class HttpProxy:
    DEFAULT_PORT = 9000

    def __init__(self, port=None, workers=None, default_target=None, log_level=None, access_log=True):
    def __init__(
        self, port=None, workers=None, default_target=None, log_level=None, access_log=True, enable_streaming=True
    ):
        # at the moment, only a fastapi proxy is supported
        self.base_proxy = FastAPIProxy(
            port or self.DEFAULT_PORT,
@@ -12,6 +14,7 @@ class HttpProxy:
            default_target=default_target,
            log_level=log_level,
            access_log=access_log,
            enable_streaming=enable_streaming,
        )
        self.base_proxy.start()
        self.port = port
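
HttpProxy itself only forwards the flag to FastAPIProxy, so streaming can also be toggled when the proxy is created directly. A short sketch; the import path is an assumption, the constructor arguments are the ones shown in the diff above:

    from clearml.router.http_proxy import HttpProxy  # import path assumed

    # the proxy starts immediately in __init__; disable streaming to buffer every response
    proxy = HttpProxy(port=9000, enable_streaming=False)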


@@ -34,6 +34,7 @@ class HttpRouter:
        )
        router.deploy(wait=True)
    """

    _instance = None

    def __init__(self, task):
@@ -45,7 +46,9 @@
        self._proxy = None
        self._proxy_params = {"port": HttpProxy.DEFAULT_PORT, "access_log": True}

    def set_local_proxy_parameters(self, incoming_port=None, default_target=None, log_level=None, access_log=True):
    def set_local_proxy_parameters(
        self, incoming_port=None, default_target=None, log_level=None, access_log=True, enable_streaming=True
    ):
        # type: (Optional[int], Optional[str], Optional[str], bool, bool) -> ()
"""
Set the parameters with which the local proxy is initialized
@ -56,11 +59,14 @@ class HttpRouter:
:param log_level: Python log level for the proxy, one of:
'critical', 'error', 'warning', 'info', 'debug', 'trace'
:param access_log: Enable/Disable access log
        :param enable_streaming: If True, responses with the `transfer-encoding` header set (i.e. chunked
            responses) are streamed back to the caller as they arrive. If False, no response is streamed;
            each response is read in full before being returned
"""
self._proxy_params["port"] = incoming_port or HttpProxy.DEFAULT_PORT
self._proxy_params["default_target"] = default_target
self._proxy_params["log_level"] = log_level
self._proxy_params["access_log"] = access_log
self._proxy_params["enable_streaming"] = enable_streaming
def start_local_proxy(self):
"""
@ -75,7 +81,7 @@ class HttpRouter:
request_callback=None, # type: Callable[Request, Dict]
response_callback=None, # type: Callable[Response, Request, Dict]
endpoint_telemetry=True, # type: Union[bool, Dict]
error_callback=None # type: Callable[Request, Exception, Dict]
error_callback=None, # type: Callable[Request, Exception, Dict]
):
"""
Create a local route from a source to a target through a proxy. If no proxy instance
@@ -135,7 +141,7 @@
            request_callback=request_callback,
            response_callback=response_callback,
            endpoint_telemetry=endpoint_telemetry,
            error_callback=error_callback
            error_callback=error_callback,
        )

    def remove_local_route(self, source):
@@ -148,9 +154,7 @@
        if self._proxy:
            self._proxy.remove_route(source)

    def deploy(
        self, wait=False, wait_interval_seconds=3.0, wait_timeout_seconds=90.0
    ):
    def deploy(self, wait=False, wait_interval_seconds=3.0, wait_timeout_seconds=90.0):
        # type: (bool, float, float) -> Optional[Dict]
"""
Start the local HTTP proxy and request an external endpoint for an application