refac
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from fastapi import (
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from open_webui.internal.db import get_session
|
||||
@@ -49,6 +49,8 @@ from open_webui.models.access_grants import AccessGrants
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.misc import (
|
||||
calculate_sha256,
|
||||
cleanup_response,
|
||||
stream_wrapper,
|
||||
)
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_ollama,
|
||||
@@ -103,14 +105,6 @@ async def send_get_request(url, key=None, user: UserModel = None):
|
||||
return None
|
||||
|
||||
|
||||
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Close the upstream response first, then the client session.

    A falsy argument (e.g. None) is skipped. The response is closed with a
    plain call; the session close is awaited.
    """
    if response:
        response.close()

    # Nothing left to release when no session was handed in.
    if not session:
        return
    await session.close()
|
||||
|
||||
|
||||
async def send_post_request(
|
||||
@@ -124,6 +118,7 @@ async def send_post_request(
|
||||
):
|
||||
|
||||
r = None
|
||||
streaming = False
|
||||
try:
|
||||
session = aiohttp.ClientSession(
|
||||
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
@@ -168,13 +163,11 @@ async def send_post_request(
|
||||
if content_type:
|
||||
response_headers["Content-Type"] = content_type
|
||||
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
stream_wrapper(r, session),
|
||||
status_code=r.status,
|
||||
headers=response_headers,
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
res = await r.json()
|
||||
@@ -190,7 +183,7 @@ async def send_post_request(
|
||||
detail=detail if e else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not stream:
|
||||
if not streaming:
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from fastapi.responses import (
|
||||
PlainTextResponse,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from open_webui.internal.db import get_session
|
||||
@@ -48,8 +48,10 @@ from open_webui.utils.payload import (
|
||||
apply_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
cleanup_response,
|
||||
convert_logit_bias_input_to_json,
|
||||
stream_chunks_handler,
|
||||
stream_wrapper,
|
||||
)
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
@@ -88,14 +90,6 @@ async def send_get_request(url, key=None, user: UserModel = None):
|
||||
return None
|
||||
|
||||
|
||||
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Release the upstream response and then the client session.

    Each argument is only touched when it is truthy, so None may be passed
    for either one. The response is closed synchronously and the session
    close is awaited.
    """
    if response:
        response.close()

    # Guard clause: no session means there is nothing further to close.
    if not session:
        return
    await session.close()
|
||||
|
||||
|
||||
def openai_reasoning_model_handler(payload):
|
||||
@@ -1104,12 +1098,9 @@ async def generate_chat_completion(
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
stream_chunks_handler(r.content),
|
||||
stream_wrapper(r, session, stream_chunks_handler),
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -1190,12 +1181,9 @@ async def embeddings(request: Request, form_data: dict, user):
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
stream_wrapper(r, session),
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -1282,12 +1270,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
stream_wrapper(r, session),
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -782,6 +782,30 @@ def extract_urls(text: str) -> list[str]:
|
||||
return url_pattern.findall(text)
|
||||
|
||||
|
||||
|
||||
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Release an upstream response and the client session it was made with.

    Either argument may be None (or otherwise falsy), in which case that
    resource is skipped.

    The session is closed in a ``finally`` clause so that an exception raised
    while closing the response cannot leave the session open. Such an
    exception still propagates to the caller after the session is closed.
    """
    try:
        if response:
            # Called without await: the response close is a plain call.
            response.close()
    finally:
        if session:
            await session.close()
|
||||
|
||||
|
||||
async def stream_wrapper(response, session, content_handler=None):
    """
    Relay the body of ``response`` chunk by chunk and clean up afterwards.

    When ``content_handler`` is given, it is called with ``response.content``
    and the chunks it produces are relayed; otherwise ``response.content`` is
    iterated directly.

    ``cleanup_response(response, session)`` is awaited in a ``finally`` clause,
    so it runs when the stream is exhausted, when iteration raises, and when
    the generator is closed early. This is intended to be more reliable than a
    BackgroundTask, which may not run if the client disconnects.

    NOTE: the cleanup is part of the generator body, so it only applies once
    iteration has started.
    """
    try:
        if content_handler:
            source = content_handler(response.content)
        else:
            source = response.content

        async for piece in source:
            yield piece
    finally:
        await cleanup_response(response, session)
|
||||
|
||||
|
||||
def stream_chunks_handler(stream: aiohttp.StreamReader):
|
||||
"""
|
||||
Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
|
||||
|
||||
Reference in New Issue
Block a user