This commit is contained in:
Timothy Jaeryang Baek 2024-12-19 11:07:02 -08:00
parent c3e8cd03b2
commit 4c989808d6

View File

@ -601,69 +601,78 @@ async def process_chat_response(request, response, user, events, metadata, tasks
if message:
messages = get_message_list(message_map, message.get("id"))
if TASKS.TITLE_GENERATION in tasks:
res = await generate_title(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res:
title = (
res.get("choices", [])[0]
.get("message", {})
.get("content", message.get("content", "New Chat"))
)
Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter(
if tasks:
if (
TASKS.TITLE_GENERATION in tasks
and tasks[TASKS.TITLE_GENERATION]
):
res = await generate_title(
request,
{
"type": "chat:title",
"data": title,
}
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if TASKS.TAGS_GENERATION in tasks:
res = await generate_chat_tags(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res:
title = (
res.get("choices", [])[0]
.get("message", {})
.get("content", message.get("content", "New Chat"))
)
if res:
tags_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
Chats.update_chat_title_by_id(
metadata["chat_id"], title
)
await event_emitter(
{
"type": "chat:tags",
"data": tags,
"type": "chat:title",
"data": title,
}
)
except Exception as e:
print(f"Error: {e}")
if (
TASKS.TAGS_GENERATION in tasks
and tasks[TASKS.TAGS_GENERATION]
):
res = await generate_chat_tags(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res:
tags_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
)
await event_emitter(
{
"type": "chat:tags",
"data": tags,
}
)
except Exception as e:
print(f"Error: {e}")
except asyncio.CancelledError:
print("Task was cancelled!")