This commit is contained in:
Timothy J. Baek 2024-07-11 13:53:47 -07:00
parent f462744fc8
commit 7d7a29cfb9

View File

@ -618,12 +618,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)}, content={"detail": str(e)},
) )
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task = None
if "task" in body:
task = body["task"]
del body["task"]
# Extract session_id, chat_id and message_id from the request body # Extract session_id, chat_id and message_id from the request body
session_id = None session_id = None
if "session_id" in body: if "session_id" in body:
@ -703,7 +697,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
"session_id": session_id, "session_id": session_id,
"chat_id": chat_id, "chat_id": chat_id,
"message_id": message_id, "message_id": message_id,
"task": task,
} }
modified_body_bytes = json.dumps(body).encode("utf-8") modified_body_bytes = json.dumps(body).encode("utf-8")
@ -1038,6 +1031,15 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task = None
if "task" in form_data:
task = form_data["task"]
del form_data["task"]
if "metadata" in form_data:
form_data["metadata"]['task'] = task
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":