Mirror of https://github.com/open-webui/open-webui (synced 2024-12-28 14:52:23 +00:00)

feat: autocompletion

commit a07213b5be
parent 0e8e9820d0
@@ -1037,6 +1037,12 @@ Only output a continuation. If you are unsure how to proceed, output nothing.
 <context>Search</context>
 <text>Best destinations for hiking in</text>
 **Output**: Europe, such as the Alps or the Scottish Highlands.

+### Input:
+<context>{{CONTEXT}}</context>
+<text>
+{{PROMPT}}
+</text>
+
 """

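For orientation, here is a minimal sketch of what the new `### Input:` block looks like once the placeholders are filled in. The substitution itself is performed by `autocomplete_generation_template` further down in this diff; the snippet below only mimics it, reusing the sample values from the few-shot example above.

```python
# Illustrative only: mimics the placeholder substitution that
# autocomplete_generation_template() performs later in this diff.
section = """### Input:
<context>{{CONTEXT}}</context>
<text>
{{PROMPT}}
</text>"""

context = "Search"                          # the autocomplete context hint
prompt = "Best destinations for hiking in"  # the text typed so far

print(section.replace("{{CONTEXT}}", context).replace("{{PROMPT}}", prompt))
# ### Input:
# <context>Search</context>
# <text>
# Best destinations for hiking in
# </text>
```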
@@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):

 @app.post("/api/task/auto/completions")
 async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
-    context = form_data.get("context")

     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
@@ -2021,8 +2020,11 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
     else:
         template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

+    context = form_data.get("context")
+    prompt = form_data.get("prompt")
+
     content = autocomplete_generation_template(
-        template, form_data["messages"], context, {"name": user.name}
+        template, prompt, context, {"name": user.name}
     )

     payload = {
@@ -2036,6 +2038,8 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
         },
     }

+    print(payload)
+
     # Handle pipeline filters
     try:
         payload = filter_pipeline(payload, user, models)
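With these endpoint changes the request body now carries `prompt` and `context` directly instead of a `messages` list, matching the fields the new `generateAutoCompletion` client helper sends later in this diff. A minimal sketch of exercising the route from a script; the base URL, token placeholder, and model id are assumptions, not part of this commit:

```python
# Hypothetical smoke test for the new endpoint; base URL, token, and model id are assumed.
import requests

resp = requests.post(
    "http://localhost:8080/api/task/auto/completions",
    headers={"Authorization": "Bearer <api-token>"},
    json={
        "model": "llama3.1",                          # assumed model id
        "prompt": "Best destinations for hiking in",  # text typed so far
        "context": "search",                          # matches the frontend default
        "stream": False,
    },
    timeout=30,
)
resp.raise_for_status()

# The frontend reads the continuation from the first choice's message content.
print(resp.json()["choices"][0]["message"]["content"])
```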
@@ -53,7 +53,9 @@ def prompt_template(

 def replace_prompt_variable(template: str, prompt: str) -> str:
     def replacement_function(match):
-        full_match = match.group(0)
+        full_match = match.group(
+            0
+        ).lower()  # Normalize to lowercase for consistent handling
         start_length = match.group(1)
         end_length = match.group(2)
         middle_length = match.group(3)
@@ -73,11 +75,9 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
             return f"{start}...{end}"
         return ""

-    template = re.sub(
-        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
-        replacement_function,
-        template,
-    )
+    # Updated regex pattern to make it case-insensitive with the `(?i)` flag
+    pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
+    template = re.sub(pattern, replacement_function, template)
     return template

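As a quick self-contained check (a sketch, not code from the repo): with the `(?i)` flag, the uppercase `{{PROMPT}}` placeholder used by the new autocomplete template and the lowercase `{{prompt}}` spelling now take the same replacement path, which is why `full_match` is lowercased above. The replacement function here is a simplified stand-in for the real one.

```python
import re

# Same pattern as in the hunk above; the leading (?i) makes it case-insensitive.
pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"

def replacement_function(match):
    # Simplified stand-in for the real replacement_function: report what matched.
    return f"<matched {match.group(0).lower()}>"

template = "<text>{{PROMPT}}</text> and {{prompt:start:8}}"
print(re.sub(pattern, replacement_function, template))
# <text><matched {{prompt}}></text> and <matched {{prompt:start:8}}>
```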
@@ -214,15 +214,12 @@ def emoji_generation_template(

 def autocomplete_generation_template(
     template: str,
-    messages: list[dict],
+    prompt: Optional[str] = None,
     context: Optional[str] = None,
     user: Optional[dict] = None,
 ) -> str:
-    prompt = get_last_user_message(messages)
     template = template.replace("{{CONTEXT}}", context if context else "")

     template = replace_prompt_variable(template, prompt)
-    template = replace_messages_variable(template, messages)

     template = prompt_template(
         template,
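Since the helper now takes the typed text directly instead of a `messages` list, a call from the endpoint side looks roughly like the sketch below. The import path and the user name are assumptions for illustration, and the template string is shortened; the real call uses `DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE`.

```python
# Assumed import path; adjust to wherever the task-template helpers live in your tree.
from open_webui.utils.task import autocomplete_generation_template

template = "<context>{{CONTEXT}}</context>\n<text>\n{{PROMPT}}\n</text>"

content = autocomplete_generation_template(
    template,
    prompt="Best destinations for hiking in",  # text typed so far
    context="search",
    user={"name": "Alice"},  # hypothetical user, mirrors {"name": user.name} in the endpoint
)
print(content)
```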
@@ -397,6 +397,53 @@ export const generateQueries = async (
     }
 };

+export const generateAutoCompletion = async (
+    token: string = '',
+    model: string,
+    prompt: string,
+    context: string = 'search',
+) => {
+    const controller = new AbortController();
+    let error = null;
+
+    const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, {
+        signal: controller.signal,
+        method: 'POST',
+        headers: {
+            Accept: 'application/json',
+            'Content-Type': 'application/json',
+            Authorization: `Bearer ${token}`
+        },
+        body: JSON.stringify({
+            model: model,
+            prompt: prompt,
+            context: context,
+            stream: false
+        })
+    })
+        .then(async (res) => {
+            if (!res.ok) throw await res.json();
+            return res.json();
+        })
+        .catch((err) => {
+            console.log(err);
+            if ('detail' in err) {
+                error = err.detail;
+            }
+            return null;
+        });
+
+    if (error) {
+        throw error;
+    }
+
+    const response = res?.choices[0]?.message?.content ?? '';
+    return response;
+};
+
 export const generateMoACompletion = async (
     token: string = '',
     model: string,
@@ -34,6 +34,8 @@
     import Commands from './MessageInput/Commands.svelte';
     import XMark from '../icons/XMark.svelte';
     import RichTextInput from '../common/RichTextInput.svelte';
+    import { generateAutoCompletion } from '$lib/apis';
+    import { error, text } from '@sveltejs/kit';

     const i18n = getContext('i18n');
@@ -47,6 +49,9 @@
     export let atSelectedModel: Model | undefined;
     export let selectedModels: [''];

+    let selectedModelIds = [];
+    $: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels;
+
     export let history;

     export let prompt = '';
@@ -581,6 +586,7 @@
             >
                 <RichTextInput
                     bind:this={chatInputElement}
+                    bind:value={prompt}
                     id="chat-input"
                     messageInput={true}
                     shiftEnter={!$mobile ||
@@ -592,7 +598,25 @@
                     placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
                     largeTextAsFile={$settings?.largeTextAsFile ?? false}
                     autocomplete={true}
-                    bind:value={prompt}
+                    generateAutoCompletion={async (text) => {
+                        if (selectedModelIds.length === 0 || !selectedModelIds.at(0)) {
+                            toast.error($i18n.t('Please select a model first.'));
+                        }
+
+                        const res = await generateAutoCompletion(
+                            localStorage.token,
+                            selectedModelIds.at(0),
+                            text
+                        ).catch((error) => {
+                            console.log(error);
+                            toast.error(error);
+                            return null;
+                        });
+
+                        console.log(res);
+
+                        return res;
+                    }}
                     on:keydown={async (e) => {
                         e = e.detail.event;

@@ -34,6 +34,7 @@
     export let value = '';
     export let id = '';

+    export let generateAutoCompletion: Function = async () => null;
     export let autocomplete = false;
     export let messageInput = false;
     export let shiftEnter = false;
@@ -159,7 +160,12 @@
                     return null;
                 }

-                return 'AI-generated suggestion';
+                const suggestion = await generateAutoCompletion(text).catch(() => null);
+                if (!suggestion || suggestion.trim().length === 0) {
+                    return null;
+                }
+
+                return suggestion;
             }
         })
     ]