From 7f74426a228d592cf4822eff94220547e0142e04 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 2 Jun 2024 18:48:45 +0100 Subject: [PATCH] fix: openai streaming cancellation using aiohttp --- backend/apps/ollama/main.py | 2 +- backend/apps/openai/main.py | 44 +++++++++++++++++++++++++------------ 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 76709b0ee..2c84f602e 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -153,7 +153,7 @@ async def cleanup_response( await session.close() -async def post_streaming_url(url, payload): +async def post_streaming_url(url: str, payload: str): r = None try: session = aiohttp.ClientSession() diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6a8347628..ea623345f 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -9,6 +9,7 @@ import json import logging from pydantic import BaseModel +from starlette.background import BackgroundTask from apps.webui.models.models import Models from apps.webui.models.users import Users @@ -194,6 +195,16 @@ async def fetch_url(url, key): return None +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + def merge_models_lists(model_lists): log.debug(f"merge_models_lists {model_lists}") merged_list = [] @@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): headers["Content-Type"] = "application/json" r = None + session = None + streaming = False try: - r = requests.request( - method=request.method, - url=target_url, - data=payload if payload else body, - headers=headers, - stream=True, + session = aiohttp.ClientSession() + r = await session.request( + method=request.method, url=target_url, data=payload, headers=headers ) r.raise_for_status() # Check if response is SSE if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True return StreamingResponse( - r.iter_content(chunk_size=8192), - status_code=r.status_code, + r.content, + status_code=r.status, headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), ) else: - response_data = r.json() + response_data = await r.json() return response_data except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" if r is not None: try: - res = r.json() + res = await r.json() print(res) if "error" in res: error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except: error_detail = f"External: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail - ) + raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + finally: + if not streaming and session: + if r: + r.close() + await session.close()