From 4c989808d6e6e8d8ffcfd050efb7d77539370146 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 19 Dec 2024 11:07:02 -0800 Subject: [PATCH] refac --- backend/open_webui/utils/middleware.py | 115 +++++++++++++------------ 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index a35432566..4ef479652 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -601,69 +601,78 @@ async def process_chat_response(request, response, user, events, metadata, tasks if message: messages = get_message_list(message_map, message.get("id")) - if TASKS.TITLE_GENERATION in tasks: - res = await generate_title( - request, - { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], - }, - user, - ) - - if res: - title = ( - res.get("choices", [])[0] - .get("message", {}) - .get("content", message.get("content", "New Chat")) - ) - - Chats.update_chat_title_by_id(metadata["chat_id"], title) - - await event_emitter( + if tasks: + if ( + TASKS.TITLE_GENERATION in tasks + and tasks[TASKS.TITLE_GENERATION] + ): + res = await generate_title( + request, { - "type": "chat:title", - "data": title, - } + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, ) - if TASKS.TAGS_GENERATION in tasks: - res = await generate_chat_tags( - request, - { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], - }, - user, - ) + if res: + title = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", message.get("content", "New Chat")) + ) - if res: - tags_string = ( - res.get("choices", [])[0] - .get("message", {}) - .get("content", "") - ) - - tags_string = tags_string[ - tags_string.find("{") : tags_string.rfind("}") + 1 - ] - - try: - tags = json.loads(tags_string).get("tags", []) - Chats.update_chat_tags_by_id( - metadata["chat_id"], tags, user + Chats.update_chat_title_by_id( + metadata["chat_id"], title ) await event_emitter( { - "type": "chat:tags", - "data": tags, + "type": "chat:title", + "data": title, } ) - except Exception as e: - print(f"Error: {e}") + + if ( + TASKS.TAGS_GENERATION in tasks + and tasks[TASKS.TAGS_GENERATION] + ): + res = await generate_chat_tags( + request, + { + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, + ) + + if res: + tags_string = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", "") + ) + + tags_string = tags_string[ + tags_string.find("{") : tags_string.rfind("}") + 1 + ] + + try: + tags = json.loads(tags_string).get("tags", []) + Chats.update_chat_tags_by_id( + metadata["chat_id"], tags, user + ) + + await event_emitter( + { + "type": "chat:tags", + "data": tags, + } + ) + except Exception as e: + print(f"Error: {e}") except asyncio.CancelledError: print("Task was cancelled!")