refac: chat completion middleware

This commit is contained in:
Timothy J. Baek 2024-06-20 02:06:10 -07:00
parent 448ca9d836
commit 6b8a7b9939

View File

@ -316,7 +316,7 @@ async def get_function_call_response(
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
data_items = []
if request.method == "POST" and (
"/ollama/api/chat" in request.url.path
@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Set the task model
task_model_id = data["model"]
task_model_id = model_id
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
skip_files = False
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
skip_files = False
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
else:
return_citations = False
if citations:
data_items.append({"citations": citations})
del data["files"]
@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
system_prompt, data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
if return_citations:
# Inject the citations into the response
# If there are data_items to inject into the response
if len(data_items) > 0:
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
self.ollama_stream_wrapper(response.body_iterator, data_items),
)
return response
@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
yield f"data: {json.dumps({'citations': citations})}\n\n"
async def openai_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, citations):
yield f"{json.dumps({'citations': citations})}\n"
async def ollama_stream_wrapper(self, original_generator, data_items):
for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator:
yield data