From 0c9119d6199f61c623ea06e0899a2f36c7ecc09d Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sat, 10 Aug 2024 13:04:01 +0100 Subject: [PATCH] move task to metadata --- backend/main.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/backend/main.py b/backend/main.py index bd7b6c8f6..0099aabb8 100644 --- a/backend/main.py +++ b/backend/main.py @@ -317,7 +317,7 @@ async def get_function_call_response( {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, - "task": str(TASKS.FUNCTION_CALLING), + "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } try: @@ -788,19 +788,21 @@ def filter_pipeline(payload, user): url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) + if key == "": + continue - r.raise_for_status() - payload = r.json() + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() except Exception as e: # Handle connection error here print(f"Connection error: {e}") @@ -1086,13 +1088,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) model = app.state.MODELS[model_id] - # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc. - if task := form_data.pop("task", None): - if "metadata" in form_data: - form_data["metadata"]["task"] = task - else: - form_data["metadata"] = {"task": task} - if model.get("pipe"): return await generate_function_chat_completion(form_data, user=user) if model["owned_by"] == "ollama": @@ -1469,7 +1464,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "stream": False, "max_tokens": 50, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.TITLE_GENERATION), + "metadata": {"task": str(TASKS.TITLE_GENERATION)}, } log.debug(payload) @@ -1522,7 +1517,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, - "task": str(TASKS.QUERY_GENERATION), + "metadata": {"task": str(TASKS.QUERY_GENERATION)}, } print(payload) @@ -1579,7 +1574,7 @@ Message: """{{prompt}}""" "stream": False, "max_tokens": 4, "chat_id": form_data.get("chat_id", None), - "task": str(TASKS.EMOJI_GENERATION), + "metadata": {"task": str(TASKS.EMOJI_GENERATION)}, } log.debug(payload)