From 1f53e0922ef9ea19a621498aea83cbe59a59a8d9 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 30 Nov 2024 00:29:27 -0800 Subject: [PATCH] enh: autocompletion --- backend/open_webui/config.py | 5 +- backend/open_webui/main.py | 4 +- backend/open_webui/utils/task.py | 6 +- src/lib/apis/index.ts | 2 + src/lib/components/chat/MessageInput.svelte | 7 ++- .../common/RichTextInput/AutoCompletion.js | 63 ++++++++++++++++--- .../workspace/models/create/+page.svelte | 5 ++ 7 files changed, 77 insertions(+), 15 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 00d80324d..bce1cde10 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1039,7 +1039,10 @@ Output: { "text": "New York City for Italian cuisine." } --- -### Input: +### Context: + +{{MESSAGES:END:6}} + {{TYPE}} {{PROMPT}} #### Output: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 292c4d62c..40724fd30 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): @app.post("/api/task/auto/completions") async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -2022,9 +2021,10 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use type = form_data.get("type") prompt = form_data.get("prompt") + messages = form_data.get("messages") content = autocomplete_generation_template( - template, prompt, type, {"name": user.name} + template, prompt, messages, type, {"name": user.name} ) payload = { diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index ea5027e4c..31de1291b 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -214,13 +214,17 @@ def emoji_generation_template( def autocomplete_generation_template( template: str, - prompt: Optional[str] = None, + prompt: str, + messages: Optional[list[dict]] = None, type: Optional[str] = None, user: Optional[dict] = None, ) -> str: template = template.replace("{{TYPE}}", type if type else "") template = replace_prompt_variable(template, prompt) + if messages: + template = replace_messages_variable(template, messages) + template = prompt_template( template, **( diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 12716cece..e2f4fa651 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -403,6 +403,7 @@ export const generateAutoCompletion = async ( token: string = '', model: string, prompt: string, + messages?: object[], type: string = 'search query', ) => { const controller = new AbortController(); @@ -419,6 +420,7 @@ export const generateAutoCompletion = async ( body: JSON.stringify({ model: model, prompt: prompt, + ...(messages && { messages: messages }), type: type, stream: false }) diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 777dff9d4..624b5176a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -18,7 +18,7 @@ showControls } from '$lib/stores'; - import { blobToFile, findWordIndices } from '$lib/utils'; + import { blobToFile, createMessagesList, findWordIndices } from '$lib/utils'; import { transcribeAudio } from '$lib/apis/audio'; import { uploadFile } from '$lib/apis/files'; import { getTools } from '$lib/apis/tools'; @@ -606,7 +606,10 @@ const res = await generateAutoCompletion( localStorage.token, selectedModelIds.at(0), - text + text, + history?.currentId + ? createMessagesList(history, history.currentId) + : null ).catch((error) => { console.log(error); toast.error(error); diff --git a/src/lib/components/common/RichTextInput/AutoCompletion.js b/src/lib/components/common/RichTextInput/AutoCompletion.js index 2412865ee..4a7ada105 100644 --- a/src/lib/components/common/RichTextInput/AutoCompletion.js +++ b/src/lib/components/common/RichTextInput/AutoCompletion.js @@ -7,6 +7,7 @@ export const AIAutocompletion = Extension.create({ addOptions() { return { generateCompletion: () => Promise.resolve(''), + debounceTime: 1000, } }, @@ -45,6 +46,9 @@ export const AIAutocompletion = Extension.create({ }, addProseMirrorPlugins() { + let debounceTimer = null; + let loading = false; + return [ new Plugin({ key: new PluginKey('aiAutocompletion'), @@ -61,6 +65,8 @@ export const AIAutocompletion = Extension.create({ if (event.key === 'Tab') { if (!node.attrs['data-suggestion']) { // Generate completion + if (loading) return true + loading = true const prompt = node.textContent this.options.generateCompletion(prompt).then(suggestion => { if (suggestion && suggestion.trim() !== '') { @@ -72,6 +78,8 @@ export const AIAutocompletion = Extension.create({ })) } // If suggestion is empty or null, do nothing + }).finally(() => { + loading = false }) } else { // Accept suggestion @@ -87,16 +95,53 @@ export const AIAutocompletion = Extension.create({ ) } return true - } else if (node.attrs['data-suggestion']) { - // Reset suggestion on any other key press - dispatch(state.tr.setNodeMarkup($head.before(), null, { - ...node.attrs, - class: null, - 'data-prompt': null, - 'data-suggestion': null, - })) - } + } else { + if (node.attrs['data-suggestion']) { + // Reset suggestion on any other key press + dispatch(state.tr.setNodeMarkup($head.before(), null, { + ...node.attrs, + class: null, + 'data-prompt': null, + 'data-suggestion': null, + })) + } + + // Set up debounce for AI generation + if (this.options.debounceTime !== null) { + clearTimeout(debounceTimer) + + // Capture current position + const currentPos = $head.before() + + debounceTimer = setTimeout(() => { + const newState = view.state + const newNode = newState.doc.nodeAt(currentPos) + + // Check if the node still exists and is still a paragraph + if (newNode && newNode.type.name === 'paragraph') { + const prompt = newNode.textContent + + if (prompt.trim() !== ''){ + if (loading) return true + loading = true + this.options.generateCompletion(prompt).then(suggestion => { + if (suggestion && suggestion.trim() !== '') { + view.dispatch(newState.tr.setNodeMarkup(currentPos, null, { + ...newNode.attrs, + class: 'ai-autocompletion', + 'data-prompt': prompt, + 'data-suggestion': suggestion, + })) + } + }).finally(() => { + loading = false + }) + } + } + }, this.options.debounceTime) + } + } return false }, }, diff --git a/src/routes/(app)/workspace/models/create/+page.svelte b/src/routes/(app)/workspace/models/create/+page.svelte index e90dd0052..b7280bf3a 100644 --- a/src/routes/(app)/workspace/models/create/+page.svelte +++ b/src/routes/(app)/workspace/models/create/+page.svelte @@ -20,6 +20,11 @@ return; } + if (modelInfo.id === '') { + toast.error('Error: Model ID cannot be empty. Please enter a valid ID to proceed.'); + return; + } + if (modelInfo) { const res = await createNewModel(localStorage.token, { ...modelInfo,