diff --git a/backend/main.py b/backend/main.py
index 11c78645b..febda4ced 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -316,7 +316,7 @@ async def get_function_call_response(
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []
 
         if request.method == "POST" and (
             "/ollama/api/chat" in request.url.path
@@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # Read the original request body
             body = await request.body()
-            # Decode body to string
             body_str = body.decode("utf-8")
-            # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
+            model_id = data["model"]
 
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
 
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
-                del data["citations"]
-
             # Set the task model
-            task_model_id = data["model"]
+            task_model_id = model_id
             if task_model_id not in app.state.MODELS:
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Model not found",
@@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ):
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
+            skip_files = False
             prompt = get_last_user_message(data["messages"])
             context = ""
 
             # If tool_ids field is present, call the functions
-
-            skip_files = False
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
@@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     context += ("\n" if context != "" else "") + rag_context
 
                 log.debug(f"rag_context: {rag_context}, citations: {citations}")
-            else:
-                return_citations = False
+
+            if citations:
+                data_items.append({"citations": citations})
 
             del data["files"]
 
@@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             )
             print(system_prompt)
             data["messages"] = add_or_update_system_message(
-                f"\n{system_prompt}", data["messages"]
+                system_prompt, data["messages"]
             )
 
         modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         response = await call_next(request)
 
-        if return_citations:
-            # Inject the citations into the response
+        # If there are data_items to inject into the response
+        if len(data_items) > 0:
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
                 if "text/event-stream" in content_type:
                     return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                     )
                 if "application/x-ndjson" in content_type:
                     return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
 
             return response
@@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
         async for data in original_generator:
             yield data
 
-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
         async for data in original_generator:
             yield data
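
For reviewers, the net effect of the wrapper change: every entry in `data_items` is emitted as its own SSE event (or NDJSON line) ahead of the proxied model stream, instead of a single hard-coded `citations` event. Below is a minimal standalone sketch of the SSE path; `fake_upstream` and the driver are hypothetical stand-ins for the upstream body iterator, not part of this PR.

```python
import asyncio
import json


async def openai_stream_wrapper(original_generator, data_items):
    # Emit each queued item as its own SSE event, then relay the model stream.
    for item in data_items:
        yield f"data: {json.dumps(item)}\n\n"
    async for data in original_generator:
        yield data


async def fake_upstream():
    # Hypothetical stand-in for the upstream SSE body iterator.
    yield 'data: {"choices": [{"delta": {"content": "Hello"}}]}\n\n'
    yield "data: [DONE]\n\n"


async def main():
    data_items = [{"citations": [{"source": {"name": "doc.pdf"}}]}]
    async for chunk in openai_stream_wrapper(fake_upstream(), data_items):
        print(chunk, end="")


asyncio.run(main())
```

The NDJSON path (`ollama_stream_wrapper`) is identical except that each item is serialized as a bare JSON line rather than an SSE `data:` event.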