move task to metadata

This commit is contained in:
Michael Poluektov 2024-08-10 13:04:01 +01:00
parent 556141cdd8
commit 0c9119d619

View File

@ -317,7 +317,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"}, {"role": "user", "content": f"Query: {prompt}"},
], ],
"stream": False, "stream": False,
"task": str(TASKS.FUNCTION_CALLING), "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
} }
try: try:
@ -788,19 +788,21 @@ def filter_pipeline(payload, user):
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "": if key == "":
headers = {"Authorization": f"Bearer {key}"} continue
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
r.raise_for_status() headers = {"Authorization": f"Bearer {key}"}
payload = r.json() r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
payload = r.json()
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") print(f"Connection error: {e}")
@ -1086,13 +1088,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
if task := form_data.pop("task", None):
if "metadata" in form_data:
form_data["metadata"]["task"] = task
else:
form_data["metadata"] = {"task": task}
if model.get("pipe"): if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user) return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
@ -1469,7 +1464,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False, "stream": False,
"max_tokens": 50, "max_tokens": 50,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"task": str(TASKS.TITLE_GENERATION), "metadata": {"task": str(TASKS.TITLE_GENERATION)},
} }
log.debug(payload) log.debug(payload)
@ -1522,7 +1517,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": False, "stream": False,
"max_tokens": 30, "max_tokens": 30,
"task": str(TASKS.QUERY_GENERATION), "metadata": {"task": str(TASKS.QUERY_GENERATION)},
} }
print(payload) print(payload)
@ -1579,7 +1574,7 @@ Message: """{{prompt}}"""
"stream": False, "stream": False,
"max_tokens": 4, "max_tokens": 4,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"task": str(TASKS.EMOJI_GENERATION), "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
} }
log.debug(payload) log.debug(payload)