diff --git a/backend/main.py b/backend/main.py index 1ce3f699a..636bbf646 100644 --- a/backend/main.py +++ b/backend/main.py @@ -618,12 +618,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): content={"detail": str(e)}, ) - # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. - task = None - if "task" in body: - task = body["task"] - del body["task"] - # Extract session_id, chat_id and message_id from the request body session_id = None if "session_id" in body: @@ -703,7 +697,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "session_id": session_id, "chat_id": chat_id, "message_id": message_id, - "task": task, } modified_body_bytes = json.dumps(body).encode("utf-8") @@ -1038,6 +1031,15 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) model = app.state.MODELS[model_id] + # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. + task = None + if "task" in form_data: + task = form_data["task"] + del form_data["task"] + + if "metadata" in form_data: + form_data["metadata"]['task'] = task + if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama":