feat: autocompletion

This commit is contained in:
Timothy Jaeryang Baek 2024-11-29 00:16:49 -08:00
parent 0e8e9820d0
commit a07213b5be
6 changed files with 98 additions and 14 deletions

View File

@ -1037,6 +1037,12 @@ Only output a continuation. If you are unsure how to proceed, output nothing.
<context>Search</context> <context>Search</context>
<text>Best destinations for hiking in</text> <text>Best destinations for hiking in</text>
**Output**: Europe, such as the Alps or the Scottish Highlands. **Output**: Europe, such as the Alps or the Scottish Highlands.
### Input:
<context>{{CONTEXT}}</context>
<text>
{{PROMPT}}
</text>
""" """

View File

@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
@app.post("/api/task/auto/completions") @app.post("/api/task/auto/completions")
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
context = form_data.get("context")
model_list = await get_all_models() model_list = await get_all_models()
models = {model["id"]: model for model in model_list} models = {model["id"]: model for model in model_list}
@ -2021,8 +2020,11 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use
else: else:
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
context = form_data.get("context")
prompt = form_data.get("prompt")
content = autocomplete_generation_template( content = autocomplete_generation_template(
template, form_data["messages"], context, {"name": user.name} template, prompt, context, {"name": user.name}
) )
payload = { payload = {
@ -2036,6 +2038,8 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use
}, },
} }
print(payload)
# Handle pipeline filters # Handle pipeline filters
try: try:
payload = filter_pipeline(payload, user, models) payload = filter_pipeline(payload, user, models)

View File

@ -53,7 +53,9 @@ def prompt_template(
def replace_prompt_variable(template: str, prompt: str) -> str: def replace_prompt_variable(template: str, prompt: str) -> str:
def replacement_function(match): def replacement_function(match):
full_match = match.group(0) full_match = match.group(
0
).lower() # Normalize to lowercase for consistent handling
start_length = match.group(1) start_length = match.group(1)
end_length = match.group(2) end_length = match.group(2)
middle_length = match.group(3) middle_length = match.group(3)
@ -73,11 +75,9 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
return f"{start}...{end}" return f"{start}...{end}"
return "" return ""
template = re.sub( # Updated regex pattern to make it case-insensitive with the `(?i)` flag
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
replacement_function, template = re.sub(pattern, replacement_function, template)
template,
)
return template return template
@ -214,15 +214,12 @@ def emoji_generation_template(
def autocomplete_generation_template( def autocomplete_generation_template(
template: str, template: str,
messages: list[dict], prompt: Optional[str] = None,
context: Optional[str] = None, context: Optional[str] = None,
user: Optional[dict] = None, user: Optional[dict] = None,
) -> str: ) -> str:
prompt = get_last_user_message(messages)
template = template.replace("{{CONTEXT}}", context if context else "") template = template.replace("{{CONTEXT}}", context if context else "")
template = replace_prompt_variable(template, prompt) template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template( template = prompt_template(
template, template,

View File

@ -397,6 +397,53 @@ export const generateQueries = async (
} }
}; };
export const generateAutoCompletion = async (
token: string = '',
model: string,
prompt: string,
context: string = 'search',
) => {
const controller = new AbortController();
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, {
signal: controller.signal,
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
prompt: prompt,
context: context,
stream: false
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
const response = res?.choices[0]?.message?.content ?? '';
return response;
};
export const generateMoACompletion = async ( export const generateMoACompletion = async (
token: string = '', token: string = '',
model: string, model: string,

View File

@ -34,6 +34,8 @@
import Commands from './MessageInput/Commands.svelte'; import Commands from './MessageInput/Commands.svelte';
import XMark from '../icons/XMark.svelte'; import XMark from '../icons/XMark.svelte';
import RichTextInput from '../common/RichTextInput.svelte'; import RichTextInput from '../common/RichTextInput.svelte';
import { generateAutoCompletion } from '$lib/apis';
import { error, text } from '@sveltejs/kit';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -47,6 +49,9 @@
export let atSelectedModel: Model | undefined; export let atSelectedModel: Model | undefined;
export let selectedModels: ['']; export let selectedModels: [''];
let selectedModelIds = [];
$: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels;
export let history; export let history;
export let prompt = ''; export let prompt = '';
@ -581,6 +586,7 @@
> >
<RichTextInput <RichTextInput
bind:this={chatInputElement} bind:this={chatInputElement}
bind:value={prompt}
id="chat-input" id="chat-input"
messageInput={true} messageInput={true}
shiftEnter={!$mobile || shiftEnter={!$mobile ||
@ -592,7 +598,25 @@
placeholder={placeholder ? placeholder : $i18n.t('Send a Message')} placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
largeTextAsFile={$settings?.largeTextAsFile ?? false} largeTextAsFile={$settings?.largeTextAsFile ?? false}
autocomplete={true} autocomplete={true}
bind:value={prompt} generateAutoCompletion={async (text) => {
if (selectedModelIds.length === 0 || !selectedModelIds.at(0)) {
toast.error($i18n.t('Please select a model first.'));
}
const res = await generateAutoCompletion(
localStorage.token,
selectedModelIds.at(0),
text
).catch((error) => {
console.log(error);
toast.error(error);
return null;
});
console.log(res);
return res;
}}
on:keydown={async (e) => { on:keydown={async (e) => {
e = e.detail.event; e = e.detail.event;

View File

@ -34,6 +34,7 @@
export let value = ''; export let value = '';
export let id = ''; export let id = '';
export let generateAutoCompletion: Function = async () => null;
export let autocomplete = false; export let autocomplete = false;
export let messageInput = false; export let messageInput = false;
export let shiftEnter = false; export let shiftEnter = false;
@ -159,7 +160,12 @@
return null; return null;
} }
return 'AI-generated suggestion'; const suggestion = await generateAutoCompletion(text).catch(() => null);
if (!suggestion || suggestion.trim().length === 0) {
return null;
}
return suggestion;
} }
}) })
] ]