is_chat_completion_request helper, remove nesting

This commit is contained in:
Michael Poluektov 2024-08-10 12:03:47 +01:00
parent 3befadb29f
commit 589efcdc5f

View File

@@ -605,12 +605,17 @@ async def chat_completion_files_handler(body):
} }
class ChatCompletionMiddleware(BaseHTTPMiddleware): def is_chat_completion_request(request):
async def dispatch(self, request: Request, call_next): return request.method == "POST" and any(
if request.method == "POST" and any(
endpoint in request.url.path endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"] for endpoint in ["/ollama/api/chat", "/chat/completions"]
): )
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not is_chat_completion_request(request):
return await call_next(request)
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
try: try:
@@ -701,11 +706,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [ request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")), (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[ *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
] ]
response = await call_next(request) response = await call_next(request)
@@ -725,10 +726,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
else: else:
return response return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
@@ -820,10 +817,9 @@ def filter_pipeline(payload, user):
class PipelineMiddleware(BaseHTTPMiddleware): class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if request.method == "POST" and ( if not is_chat_completion_request(request):
"/ollama/api/chat" in request.url.path return await call_next(request)
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
# Read the original request body # Read the original request body
@@ -852,11 +848,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [ request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")), (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[ *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
] ]
response = await call_next(request) response = await call_next(request)