From 4dd77b785ad4af9b9b14905679acdeb8ebab6940 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Thu, 11 Jul 2024 14:12:44 -0700
Subject: [PATCH] fix

---
 backend/apps/ollama/main.py |  2 +-
 backend/main.py             | 18 ++++++++++--------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 09df6cc33..9df8719e9 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -895,8 +895,8 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
-    payload = {**form_data}
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
 
     if "metadata" in payload:
         del payload["metadata"]
 
diff --git a/backend/main.py b/backend/main.py
index 636bbf646..01c2fde2a 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -317,7 +317,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": TASKS.FUNCTION_CALLING,
+        "task": str(TASKS.FUNCTION_CALLING),
     }
 
     try:
@@ -632,8 +632,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             message_id = body["id"]
             del body["id"]
 
-
-
         __event_emitter__ = await get_event_emitter(
             {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
         )
@@ -1037,12 +1035,16 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         task = form_data["task"]
         del form_data["task"]
 
-        if "metadata" in form_data:
-            form_data["metadata"]['task'] = task
+        if task:
+            if "metadata" in form_data:
+                form_data["metadata"]["task"] = task
+            else:
+                form_data["metadata"] = {"task": task}
 
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
+        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1311,7 +1313,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.TITLE_GENERATION,
+        "task": str(TASKS.TITLE_GENERATION),
     }
 
     log.debug(payload)
@@ -1364,7 +1366,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": TASKS.QUERY_GENERATION,
+        "task": str(TASKS.QUERY_GENERATION),
     }
 
     print(payload)
@@ -1421,7 +1423,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.EMOJI_GENERATION,
+        "task": str(TASKS.EMOJI_GENERATION),
     }
 
     log.debug(payload)
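
A minimal sketch of what the backend/apps/ollama/main.py change does, assuming Pydantic v2. OpenAIChatCompletionForm below is a simplified stand-in with illustrative fields, and the sketch passes a set to exclude where the patch passes a list:

    from typing import Optional
    from pydantic import BaseModel

    class OpenAIChatCompletionForm(BaseModel):  # simplified stand-in for the real form
        model: str
        messages: list
        temperature: Optional[float] = None
        metadata: Optional[dict] = None

    form = OpenAIChatCompletionForm(
        model="llama3",
        messages=[{"role": "user", "content": "hi"}],
        metadata={"task": "title_generation"},
    )

    # Pre-patch code spread the model instance itself; a BaseModel is not a
    # mapping, so {**form} raises TypeError.
    # payload = {**form}

    # Post-patch: dump to a plain dict, dropping None fields and "metadata".
    payload = {**form.model_dump(exclude_none=True, exclude={"metadata"})}
    print(payload)  # {'model': 'llama3', 'messages': [{'role': 'user', 'content': 'hi'}]}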
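
A second sketch for the str(TASKS.*) wrapping in backend/main.py: if the task constant is an Enum member, dropping it straight into a payload that is later json-encoded fails, while str() yields a plain string. The TASKS class below, including its __str__ override, is a hypothetical stand-in rather than the project's actual definition:

    import json
    from enum import Enum

    class TASKS(Enum):  # hypothetical stand-in for the backend's task constants
        FUNCTION_CALLING = "function_calling"
        TITLE_GENERATION = "title_generation"

        def __str__(self) -> str:
            return self.value

    payload = {"stream": False, "task": TASKS.TITLE_GENERATION}
    # json.dumps(payload) -> TypeError: Object of type TASKS is not JSON serializable

    payload["task"] = str(TASKS.TITLE_GENERATION)
    print(json.dumps(payload))  # {"stream": false, "task": "title_generation"}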