Refactor the OpenAI (SSE) and Ollama (NDJSON) stream wrappers into a single wrapper

This commit is contained in:
Michael Poluektov 2024-08-15 17:03:42 +01:00
parent 446b2a334a
commit 3164354c0b

View File

@@ -681,36 +681,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers["Content-Type"] content_type = response.headers["Content-Type"]
if "text/event-stream" in content_type: is_openai = "text/event-stream" in content_type
return StreamingResponse( is_ollama = "application/x-ndjson" in content_type
self.openai_stream_wrapper(response.body_iterator, data_items), if not is_openai and not is_ollama:
) return response
if "application/x-ndjson" in content_type:
return StreamingResponse( def wrap_item(item):
self.ollama_stream_wrapper(response.body_iterator, data_items), return f"data: {item}\n\n" if is_openai else f"{item}\n"
)
async def stream_wrapper(original_generator, data_items):
for item in data_items:
yield wrap_item(json.dumps(item))
async for data in original_generator:
yield data
return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator:
yield data
app.add_middleware(ChatCompletionMiddleware) app.add_middleware(ChatCompletionMiddleware)