diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index a941627cf..72b16797c 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1037,6 +1037,12 @@ Only output a continuation. If you are unsure how to proceed, output nothing. Search Best destinations for hiking in **Output**: Europe, such as the Alps or the Scottish Highlands. + +### Input: +{{CONTEXT}} + +{{PROMPT}} + """ diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 311bf3968..69dd84d8b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): @app.post("/api/task/auto/completions") async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): - context = form_data.get("context") model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -2021,8 +2020,11 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use else: template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + context = form_data.get("context") + prompt = form_data.get("prompt") + content = autocomplete_generation_template( - template, form_data["messages"], context, {"name": user.name} + template, prompt, context, {"name": user.name} ) payload = { @@ -2036,6 +2038,8 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use }, } + print(payload) + # Handle pipeline filters try: payload = filter_pipeline(payload, user, models) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 61e46f5ac..401d546d0 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -53,7 +53,9 @@ def prompt_template( def replace_prompt_variable(template: str, prompt: str) -> str: def replacement_function(match): - full_match = match.group(0) + full_match = match.group( + 0 + ).lower() # Normalize to lowercase for consistent handling start_length = match.group(1) end_length = match.group(2) middle_length = match.group(3) @@ -73,11 +75,9 @@ def replace_prompt_variable(template: str, prompt: str) -> str: return f"{start}...{end}" return "" - template = re.sub( - r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", - replacement_function, - template, - ) + # Updated regex pattern to make it case-insensitive with the `(?i)` flag + pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}" + template = re.sub(pattern, replacement_function, template) return template @@ -214,15 +214,12 @@ def emoji_generation_template( def autocomplete_generation_template( template: str, - messages: list[dict], + prompt: Optional[str] = None, context: Optional[str] = None, user: Optional[dict] = None, ) -> str: - prompt = get_last_user_message(messages) template = template.replace("{{CONTEXT}}", context if context else "") - template = replace_prompt_variable(template, prompt) - template = replace_messages_variable(template, messages) template = prompt_template( template, diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 2e9e836a8..6f72685ca 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -397,6 +397,53 @@ export const generateQueries = async ( } }; + + +export const generateAutoCompletion = async ( + token: string = '', + model: string, + prompt: string, + context: string = 'search', +) => { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + context: context, + stream: false + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + const response = res?.choices[0]?.message?.content ?? ''; + return response; +}; + + export const generateMoACompletion = async ( token: string = '', model: string, diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 6ebb29428..777dff9d4 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -34,6 +34,8 @@ import Commands from './MessageInput/Commands.svelte'; import XMark from '../icons/XMark.svelte'; import RichTextInput from '../common/RichTextInput.svelte'; + import { generateAutoCompletion } from '$lib/apis'; + import { error, text } from '@sveltejs/kit'; const i18n = getContext('i18n'); @@ -47,6 +49,9 @@ export let atSelectedModel: Model | undefined; export let selectedModels: ['']; + let selectedModelIds = []; + $: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels; + export let history; export let prompt = ''; @@ -581,6 +586,7 @@ > { + if (selectedModelIds.length === 0 || !selectedModelIds.at(0)) { + toast.error($i18n.t('Please select a model first.')); + } + + const res = await generateAutoCompletion( + localStorage.token, + selectedModelIds.at(0), + text + ).catch((error) => { + console.log(error); + toast.error(error); + return null; + }); + + console.log(res); + + return res; + }} on:keydown={async (e) => { e = e.detail.event; diff --git a/src/lib/components/common/RichTextInput.svelte b/src/lib/components/common/RichTextInput.svelte index e2ca453ac..6b208d6e2 100644 --- a/src/lib/components/common/RichTextInput.svelte +++ b/src/lib/components/common/RichTextInput.svelte @@ -34,6 +34,7 @@ export let value = ''; export let id = ''; + export let generateAutoCompletion: Function = async () => null; export let autocomplete = false; export let messageInput = false; export let shiftEnter = false; @@ -159,7 +160,12 @@ return null; } - return 'AI-generated suggestion'; + const suggestion = await generateAutoCompletion(text).catch(() => null); + if (!suggestion || suggestion.trim().length === 0) { + return null; + } + + return suggestion; } }) ]