fix: openai streaming cancellation using aiohttp

Jun Siang Cheah 2024-06-02 18:48:45 +01:00
parent 4dd51badfe
commit 7f74426a22
2 changed files with 31 additions and 15 deletions


@@ -153,7 +153,7 @@ async def cleanup_response(
         await session.close()


-async def post_streaming_url(url, payload):
+async def post_streaming_url(url: str, payload: str):
     r = None
     try:
         session = aiohttp.ClientSession()

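This first hunk only tightens the signature of post_streaming_url with type hints, but it belongs to the same aiohttp streaming pattern the rest of the commit introduces. Below is a minimal sketch of that pattern, assuming a plain POST and the cleanup_response helper named in the hunk header; the error handling shown here is illustrative, not the file's exact body:

```python
from typing import Optional

import aiohttp
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    # Stand-in for the helper referenced in the hunk header above.
    if response:
        response.close()
    if session:
        await session.close()


async def post_streaming_url(url: str, payload: str):
    r = None
    session = None
    try:
        session = aiohttp.ClientSession()
        r = await session.post(url, data=payload)
        r.raise_for_status()

        # Hand aiohttp's byte stream straight to Starlette; the BackgroundTask
        # closes the response and session once the stream finishes or the
        # client cancels, instead of leaking the upstream connection.
        return StreamingResponse(
            r.content,
            status_code=r.status,
            headers=dict(r.headers),
            background=BackgroundTask(cleanup_response, response=r, session=session),
        )
    except Exception as e:
        # On failure, release whatever was opened before surfacing the error.
        await cleanup_response(r, session)
        raise HTTPException(status_code=r.status if r else 500, detail=str(e))
```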

@@ -9,6 +9,7 @@ import json
 import logging

 from pydantic import BaseModel
+from starlette.background import BackgroundTask

 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
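The new import is Starlette's BackgroundTask, whose job is to defer a callable until after a response has been fully sent. A tiny standalone example of that API (the route and callback names here are made up for illustration):

```python
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from starlette.background import BackgroundTask

app = FastAPI()


async def log_done(path: str):
    # Invoked only after the response body has been sent to the client.
    print(f"finished sending {path}")


@app.get("/ping")
async def ping():
    # BackgroundTask(func, *args, **kwargs) wraps a callable that Starlette
    # runs once the response (including a streamed one) has completed.
    return PlainTextResponse("pong", background=BackgroundTask(log_done, "/ping"))
```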
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
         return None


+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
 def merge_models_lists(model_lists):
     log.debug(f"merge_models_lists {model_lists}")
     merged_list = []
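cleanup_response mirrors an asymmetry in aiohttp's API: ClientResponse.close() is a plain method, while ClientSession.close() is a coroutine and must be awaited. Both arguments are optional, so the helper is safe to call from an error path where the request never completed. A small self-contained usage sketch (the URL is a placeholder):

```python
import asyncio
from typing import Optional

import aiohttp


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    # Same shape as the helper added in the hunk above.
    if response:
        response.close()
    if session:
        await session.close()


async def demo():
    session = aiohttp.ClientSession()
    r = None
    try:
        r = await session.get("https://example.com")  # placeholder URL
        print(r.status)
    finally:
        # Safe even if the request failed and r is still None.
        await cleanup_response(r, session)


asyncio.run(demo())
```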
@@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         headers["Content-Type"] = "application/json"

     r = None
+    session = None
+    streaming = False
+
     try:
-        r = requests.request(
-            method=request.method,
-            url=target_url,
-            data=payload if payload else body,
-            headers=headers,
-            stream=True,
-        )
+        session = aiohttp.ClientSession()
+        r = await session.request(
+            method=request.method, url=target_url, data=payload, headers=headers
+        )

         r.raise_for_status()

         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
             return StreamingResponse(
-                r.iter_content(chunk_size=8192),
-                status_code=r.status_code,
+                r.content,
+                status_code=r.status,
                 headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
             )
         else:
-            response_data = r.json()
+            response_data = await r.json()
             return response_data
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
-                res = r.json()
+                res = await r.json()
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except:
                 error_detail = f"External: {e}"

-        raise HTTPException(
-            status_code=r.status_code if r else 500, detail=error_detail
-        )
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
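Read without the diff markers, the "+" side of this last hunk gives the proxy roughly the flow sketched below. The sketch condenses it into a standalone handler, so the auth logic and the way target_url, payload, and headers are actually built in the real proxy are assumed or omitted here:

```python
from typing import Optional

import aiohttp
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST"])
async def proxy(path: str, request: Request):
    # Assumptions for the sketch only: a fixed upstream and a pass-through body.
    target_url = f"https://api.openai.com/v1/{path}"
    payload = await request.body()
    headers = {"Content-Type": "application/json"}

    r = None
    session = None
    streaming = False
    try:
        session = aiohttp.ClientSession()
        r = await session.request(
            method=request.method, url=target_url, data=payload, headers=headers
        )
        r.raise_for_status()

        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # Stream SSE bytes straight through; cleanup is deferred to the
            # BackgroundTask so a client cancellation also closes the upstream
            # response and session.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(cleanup_response, response=r, session=session),
            )
        return await r.json()
    except Exception as e:
        raise HTTPException(
            status_code=r.status if r else 500,
            detail=f"Open WebUI: Server Connection Error ({e})",
        )
    finally:
        # Non-streaming requests (and failures) are torn down immediately;
        # streaming ones must keep the session open for the StreamingResponse.
        if not streaming and session:
            if r:
                r.close()
            await session.close()
```

The streaming flag is what keeps the finally block from closing the session out from under the StreamingResponse; for SSE responses the BackgroundTask owns cleanup instead, which is what makes client-side cancellation release the upstream connection.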