From d795940ced349c3d65dbed8bf007d6ff8cbf4bc1 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 19 Oct 2024 20:34:17 -0700 Subject: [PATCH] feat: chat auto tag --- backend/open_webui/constants.py | 1 + backend/open_webui/main.py | 67 +++++++++++++++++ backend/open_webui/utils/task.py | 18 +++++ src/lib/apis/index.ts | 72 +++++++++++++++++++ src/lib/components/chat/Chat.svelte | 54 +++++++++++--- .../components/chat/Settings/Interface.svelte | 28 ++++++++ .../layout/Sidebar/SearchInput.svelte | 4 +- 7 files changed, 233 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index c52c398cb..d6f33af4a 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -108,6 +108,7 @@ class TASKS(str, Enum): DEFAULT = lambda task="": f"{task if task else 'generation'}" TITLE_GENERATION = "title_generation" + TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" FUNCTION_CALLING = "function_calling" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 169a9ea4f..0d9df6a84 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -134,6 +134,7 @@ from open_webui.utils.misc import ( ) from open_webui.utils.task import ( moa_response_generation_template, + tags_generation_template, search_query_generation_template, title_generation_template, tools_function_calling_generation_template, @@ -1545,6 +1546,72 @@ Prompt: {{prompt:middletruncate:8000}}""" return await generate_chat_completions(form_data=payload, user=user) +@app.post("/api/task/tags/completions") +async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): + print("generate_chat_tags") + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id(model_id) + print(task_model_id) + + template = """### Task: +Generate 1-3 broad tags categorizing the main themes of the chat history. + +### Guidelines: +- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education) +- Only add more specific subdomains if they are strongly represented throughout the conversation +- If content is too short (less than 3 messages) or too diverse, use only ["General"] +- Use the chat's primary language; default to English if multilingual +- Prioritize accuracy over specificity + +### Output: +JSON format: { "tags": ["tag1", "tag2", "tag3"] } + +### Chat History: + +{{MESSAGES:END:6}} +""" + + content = tags_generation_template( + template, form_data["messages"], {"name": user.name} + ) + + print("content", content) + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, + } + log.debug(payload) + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + @app.post("/api/task/query/completions") async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): print("generate_search_query") diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index e7cab76cf..7f7876fc5 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -123,6 +123,24 @@ def replace_messages_variable(template: str, messages: list[str]) -> str: return template +def tags_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def search_query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 843255478..2b3218cb1 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -245,6 +245,78 @@ export const generateTitle = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat'; }; +export const generateTags = async ( + token: string = '', + model: string, + messages: string, + chat_id?: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/tags/completions`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + messages: messages, + ...(chat_id && { chat_id: chat_id }) + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + try { + // Step 1: Safely extract the response string + const response = res?.choices[0]?.message?.content ?? ''; + + // Step 2: Attempt to fix common JSON format issues like single quotes + const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON + + // Step 3: Find the relevant JSON block within the response + const jsonStartIndex = sanitizedResponse.indexOf('{'); + const jsonEndIndex = sanitizedResponse.lastIndexOf('}'); + + // Step 4: Check if we found a valid JSON block (with both `{` and `}`) + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "tags" key, return the tags array; otherwise, return an empty array + if (parsed && parsed.tags) { + return Array.isArray(parsed.tags) ? parsed.tags : []; + } else { + return []; + } + } + + // If no valid JSON block found, return an empty array + return []; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return []; + } +}; + export const generateEmoji = async ( token: string = '', model: string, diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index ebaa2baac..8660034bd 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -10,7 +10,7 @@ import { goto } from '$app/navigation'; import { page } from '$app/stores'; - import type { Unsubscriber, Writable } from 'svelte/store'; + import { get, type Unsubscriber, type Writable } from 'svelte/store'; import type { i18n as i18nType } from 'i18next'; import { WEBUI_BASE_URL } from '$lib/constants'; @@ -20,6 +20,7 @@ config, type Model, models, + tags as allTags, settings, showSidebar, WEBUI_NAME, @@ -46,7 +47,9 @@ import { generateChatCompletion } from '$lib/apis/ollama'; import { + addTagById, createNewChat, + getAllTags, getChatById, getChatList, getTagsById, @@ -62,7 +65,8 @@ generateTitle, generateSearchQuery, chatAction, - generateMoACompletion + generateMoACompletion, + generateTags } from '$lib/apis'; import Banner from '../common/Banner.svelte'; @@ -537,7 +541,10 @@ }); if (chat) { - tags = await getTags(); + tags = await getTagsById(localStorage.token, $chatId).catch(async (error) => { + return []; + }); + const chatContent = chat.chat; if (chatContent) { @@ -1393,6 +1400,10 @@ window.history.replaceState(history.state, '', `/c/${_chatId}`); const title = await generateChatTitle(userPrompt); await setChatTitle(_chatId, title); + + if ($settings?.autoTags ?? true) { + await setChatTags(messages); + } } return _response; @@ -1707,6 +1718,10 @@ window.history.replaceState(history.state, '', `/c/${_chatId}`); const title = await generateChatTitle(userPrompt); await setChatTitle(_chatId, title); + + if ($settings?.autoTags ?? true) { + await setChatTags(messages); + } } return _response; @@ -1893,6 +1908,33 @@ } }; + const setChatTags = async (messages) => { + if (!$temporaryChatEnabled) { + let generatedTags = await generateTags( + localStorage.token, + selectedModels[0], + messages, + $chatId + ).catch((error) => { + console.error(error); + return []; + }); + + const currentTags = await getTagsById(localStorage.token, $chatId); + generatedTags = generatedTags.filter( + (tag) => !currentTags.find((t) => t.id === tag.replaceAll(' ', '_').toLowerCase()) + ); + console.log(generatedTags); + + for (const tag of generatedTags) { + await addTagById(localStorage.token, $chatId, tag); + } + + chat = await getChatById(localStorage.token, $chatId); + allTags.set(await getAllTags(localStorage.token)); + } + }; + const getWebSearchResults = async ( model: string, parentId: string, @@ -1978,12 +2020,6 @@ } }; - const getTags = async () => { - return await getTagsById(localStorage.token, $chatId).catch(async (error) => { - return []; - }); - }; - const initChatHandler = async () => { if (!$temporaryChatEnabled) { chat = await createNewChat(localStorage.token, { diff --git a/src/lib/components/chat/Settings/Interface.svelte b/src/lib/components/chat/Settings/Interface.svelte index 50cdc0559..b58704ffd 100644 --- a/src/lib/components/chat/Settings/Interface.svelte +++ b/src/lib/components/chat/Settings/Interface.svelte @@ -19,6 +19,8 @@ // Addons let titleAutoGenerate = true; + let autoTags = true; + let responseAutoCopy = false; let widescreenMode = false; let splitLargeChunks = false; @@ -112,6 +114,11 @@ }); }; + const toggleAutoTags = async () => { + autoTags = !autoTags; + saveSettings({ autoTags }); + }; + const toggleResponseAutoCopy = async () => { const permission = await navigator.clipboard .readText() @@ -149,6 +156,7 @@ onMount(async () => { titleAutoGenerate = $settings?.title?.auto ?? true; + autoTags = $settings.autoTags ?? true; responseAutoCopy = $settings.responseAutoCopy ?? false; showUsername = $settings.showUsername ?? false; @@ -431,6 +439,26 @@ +
+
+
{$i18n.t('Chat Tags Auto-Generation')}
+ + +
+
+
diff --git a/src/lib/components/layout/Sidebar/SearchInput.svelte b/src/lib/components/layout/Sidebar/SearchInput.svelte index 5e8213308..908fd9db6 100644 --- a/src/lib/components/layout/Sidebar/SearchInput.svelte +++ b/src/lib/components/layout/Sidebar/SearchInput.svelte @@ -144,7 +144,7 @@ {#if filteredTags.length > 0}
Tags
-
+
{#each filteredTags as tag, tagIdx}
-
+
{#each filteredOptions as option, optionIdx}