From 589efcdc5fe61754112312ef275fa6f164362efc Mon Sep 17 00:00:00 2001
From: Michael Poluektov
Date: Sat, 10 Aug 2024 12:03:47 +0100
Subject: [PATCH] is_chat_completion_request helper, remove nesting

---
 backend/main.py | 292 +++++++++++++++++++++++-------------------------
 1 file changed, 142 insertions(+), 150 deletions(-)

diff --git a/backend/main.py b/backend/main.py
index 3fc6b8db5..b1cd298a2 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -605,129 +605,126 @@ async def chat_completion_files_handler(body):
     }
 
 
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
+
+
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
 
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
+
+        __event_emitter__ = get_event_emitter(metadata)
+        __event_call__ = get_event_call(metadata)
+
+        # Initialize data_items to store additional data to be sent to the client
+        data_items = []
+
+        # Initialize contexts and citations
+        contexts = []
+        citations = []
+
+        try:
+            body, flags = await chat_completion_functions_handler(
+                body, model, user, __event_emitter__, __event_call__
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+        try:
+            body, flags = await chat_completion_tools_handler(
+                body, user, __event_emitter__, __event_call__
+            )
+
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            print(e)
+            pass
+
+        try:
+            body, flags = await chat_completion_files_handler(body)
+
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            print(e)
+            pass
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
                 )
-
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
-
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
-
-            # Initialize context, and citations
-            contexts = []
-            citations = []
-
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
-
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            try:
-                body, flags = await chat_completion_files_handler(body)
-
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
-
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "/n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-
-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
-
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
-
-                return response
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
-        return response
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers.get("Content-Type")
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
+            return response
+        else:
+            return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
@@ -820,44 +817,39 @@ def filter_pipeline(payload, user):
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+        if not is_chat_completion_request(request):
+            return await call_next(request)
 
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
+        log.debug(f"request.url.path: {request.url.path}")
 
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers.get("Authorization")),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
            )
 
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
-
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
         response = await call_next(request)
         return response
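
Reviewer note: below is a minimal, runnable sketch of the pattern both middlewares now share — the is_chat_completion_request predicate plus a guard clause that returns early, followed by the same body/content-length rewrite. Everything else (the BodyRewriteMiddleware name, the "rewritten" marker, the echo route and app wiring) is illustrative and not part of this patch. Note that request._body and headers.__dict__["_list"] are Starlette internals, so this sketch, like the patch itself, is tied to the behavior of the Starlette version in use.

    import json

    from starlette.applications import Starlette
    from starlette.middleware import Middleware
    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request
    from starlette.responses import JSONResponse
    from starlette.routing import Route


    def is_chat_completion_request(request: Request) -> bool:
        # Same predicate as the helper introduced by this patch.
        return request.method == "POST" and any(
            endpoint in request.url.path
            for endpoint in ["/ollama/api/chat", "/chat/completions"]
        )


    class BodyRewriteMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            # Guard clause: anything that is not a chat completion passes
            # straight through, so the main path below needs no extra nesting.
            if not is_chat_completion_request(request):
                return await call_next(request)

            body = json.loads(await request.body() or b"{}")
            body["rewritten"] = True  # stand-in for the real handlers

            modified_body_bytes = json.dumps(body).encode("utf-8")
            # Swap in the new body and keep content-length consistent with it,
            # the same way the patch does (both touch Starlette internals).
            request._body = modified_body_bytes
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
            ]
            return await call_next(request)


    async def chat_completions(request: Request):
        # Echo the request body as the endpoint received it.
        return JSONResponse(await request.json())


    app = Starlette(
        routes=[Route("/chat/completions", chat_completions, methods=["POST"])],
        middleware=[Middleware(BodyRewriteMiddleware)],
    )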