refac: chat completion middleware

Timothy J. Baek 2024-06-20 02:06:10 -07:00
parent 448ca9d836
commit 6b8a7b9939


@@ -316,7 +316,7 @@ async def get_function_call_response(
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []
 
         if request.method == "POST" and (
             "/ollama/api/chat" in request.url.path
@@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # Read the original request body
             body = await request.body()
-            # Decode body to string
             body_str = body.decode("utf-8")
-            # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
+
+            model_id = data["model"]
 
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
 
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
-                del data["citations"]
-
             # Set the task model
-            task_model_id = data["model"]
+            task_model_id = model_id
             if task_model_id not in app.state.MODELS:
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
@@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             ):
                 task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
+            skip_files = False
             prompt = get_last_user_message(data["messages"])
             context = ""
 
             # If tool_ids field is present, call the functions
-            skip_files = False
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
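Note: `skip_files` is now initialized before the prompt/context setup rather than beside the tool loop, which makes its lifetime clearer: a tool handler may set it, and the files block in the next hunk consults it before running retrieval. A condensed, hypothetical sketch of that control flow (the handler bodies and tool id are stand-ins, not the diff's code):

    def handle(data: dict) -> None:
        skip_files = False

        if "tool_ids" in data:
            for tool_id in data["tool_ids"]:
                # Hypothetical: a tool that consumes the files itself
                # would flip the flag so RAG retrieval is skipped below.
                if tool_id == "file_reader":
                    skip_files = True

        if "files" in data and not skip_files:
            print("run RAG retrieval over data['files']")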
@@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     context += ("\n" if context != "" else "") + rag_context
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
-                else:
-                    return_citations = False
+
+                if citations:
+                    data_items.append({"citations": citations})
 
                 del data["files"]
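The single-purpose `return_citations` flag is gone: anything that should reach the client out-of-band is appended to `data_items` and replayed ahead of the model stream. A minimal sketch of the accumulator, with a hypothetical citations shape (the real structure comes from `get_rag_context`):

    data_items = []

    # Hypothetical shape, for illustration only.
    citations = [{"source": {"name": "manual.pdf"}, "document": ["chunk text"]}]
    if citations:
        data_items.append({"citations": citations})

Each appended dict later becomes one SSE event or one NDJSON line in the wrappers at the end of this diff.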
@@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
                 print(system_prompt)
                 data["messages"] = add_or_update_system_message(
-                    f"\n{system_prompt}", data["messages"]
+                    system_prompt, data["messages"]
                 )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
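For the rewritten body to reach the downstream handler, the middleware has to replay it on the request's receive channel; that is what the `_receive` helper in the last hunk is for. A minimal sketch of the likely wiring inside dispatch, assuming it happens right after re-encoding (the rebinding line itself is outside the shown hunks):

    from functools import partial

    modified_body_bytes = json.dumps(data).encode("utf-8")

    # Rebind the ASGI receive callable so Starlette re-reads the
    # mutated JSON instead of the original body.
    request._receive = partial(self._receive, modified_body_bytes)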
@@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
 
-        if return_citations:
-            # Inject the citations into the response
+        # If there are data_items to inject into the response
+        if len(data_items) > 0:
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
                 if "text/event-stream" in content_type:
                     return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                     )
 
                 if "application/x-ndjson" in content_type:
                     return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
 
         return response
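Both branches inject the same items; only the framing differs. For a hypothetical citations item, the injected prefix would look roughly like this on the wire (escape sequences shown literally):

    text/event-stream:     data: {"citations": [...]}\n\n
    application/x-ndjson:  {"citations": [...]}\n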
@@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
         async for data in original_generator:
             yield data
 
-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
         async for data in original_generator:
             yield data
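A standalone sketch of the wrapper behaviour, runnable outside the middleware (the upstream generator and item contents below are stand-ins):

    import asyncio
    import json

    async def openai_stream_wrapper(original_generator, data_items):
        # Same logic as above: flush injected items first, then pass through.
        for item in data_items:
            yield f"data: {json.dumps(item)}\n\n"
        async for data in original_generator:
            yield data

    async def fake_upstream():  # stand-in for response.body_iterator
        yield 'data: {"choices": [{"delta": {"content": "Hi"}}]}\n\n'

    async def main():
        items = [{"citations": [{"source": "example.txt"}]}]
        async for chunk in openai_stream_wrapper(fake_upstream(), items):
            print(chunk, end="")

    asyncio.run(main())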