From 3164354c0b86517a2e168ac50f44c4abe1d319d8 Mon Sep 17 00:00:00 2001
From: Michael Poluektov
Date: Thu, 15 Aug 2024 17:03:42 +0100
Subject: [PATCH] refactor into single wrapper

---
 backend/main.py | 39 ++++++++++++++++-----------------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/backend/main.py b/backend/main.py
index d539834ed..411c33e1c 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -681,36 +681,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
 
         if isinstance(response, StreamingResponse):
-            # If it's a streaming response, inject it as SSE event or NDJSON line
             content_type = response.headers["Content-Type"]
 
-            if "text/event-stream" in content_type:
-                return StreamingResponse(
-                    self.openai_stream_wrapper(response.body_iterator, data_items),
-                )
-            if "application/x-ndjson" in content_type:
-                return StreamingResponse(
-                    self.ollama_stream_wrapper(response.body_iterator, data_items),
-                )
+            is_openai = "text/event-stream" in content_type
+            is_ollama = "application/x-ndjson" in content_type
+            if not is_openai and not is_ollama:
+                return response
+
+            def wrap_item(item):
+                return f"data: {item}\n\n" if is_openai else f"{item}\n"
+
+            async def stream_wrapper(original_generator, data_items):
+                for item in data_items:
+                    yield wrap_item(json.dumps(item))
+
+                async for data in original_generator:
+                    yield data
+
+            return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
 
         return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
-    async def openai_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"data: {json.dumps(item)}\n\n"
-
-        async for data in original_generator:
-            yield data
-
-    async def ollama_stream_wrapper(self, original_generator, data_items):
-        for item in data_items:
-            yield f"{json.dumps(item)}\n"
-
-        async for data in original_generator:
-            yield data
-
 
 app.add_middleware(ChatCompletionMiddleware)