enh: autocompletion

This commit is contained in:
Timothy Jaeryang Baek 2024-11-30 00:29:27 -08:00
parent ba6dc71810
commit 1f53e0922e
7 changed files with 77 additions and 15 deletions

View File

@ -1039,7 +1039,10 @@ Output:
{ "text": "New York City for Italian cuisine." } { "text": "New York City for Italian cuisine." }
--- ---
### Input: ### Context:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>
<type>{{TYPE}}</type> <type>{{TYPE}}</type>
<text>{{PROMPT}}</text> <text>{{PROMPT}}</text>
#### Output: #### Output:

View File

@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
@app.post("/api/task/auto/completions") @app.post("/api/task/auto/completions")
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
model_list = await get_all_models() model_list = await get_all_models()
models = {model["id"]: model for model in model_list} models = {model["id"]: model for model in model_list}
@ -2022,9 +2021,10 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use
type = form_data.get("type") type = form_data.get("type")
prompt = form_data.get("prompt") prompt = form_data.get("prompt")
messages = form_data.get("messages")
content = autocomplete_generation_template( content = autocomplete_generation_template(
template, prompt, type, {"name": user.name} template, prompt, messages, type, {"name": user.name}
) )
payload = { payload = {

View File

@ -214,13 +214,17 @@ def emoji_generation_template(
def autocomplete_generation_template( def autocomplete_generation_template(
template: str, template: str,
prompt: Optional[str] = None, prompt: str,
messages: Optional[list[dict]] = None,
type: Optional[str] = None, type: Optional[str] = None,
user: Optional[dict] = None, user: Optional[dict] = None,
) -> str: ) -> str:
template = template.replace("{{TYPE}}", type if type else "") template = template.replace("{{TYPE}}", type if type else "")
template = replace_prompt_variable(template, prompt) template = replace_prompt_variable(template, prompt)
if messages:
template = replace_messages_variable(template, messages)
template = prompt_template( template = prompt_template(
template, template,
**( **(

View File

@ -403,6 +403,7 @@ export const generateAutoCompletion = async (
token: string = '', token: string = '',
model: string, model: string,
prompt: string, prompt: string,
messages?: object[],
type: string = 'search query', type: string = 'search query',
) => { ) => {
const controller = new AbortController(); const controller = new AbortController();
@ -419,6 +420,7 @@ export const generateAutoCompletion = async (
body: JSON.stringify({ body: JSON.stringify({
model: model, model: model,
prompt: prompt, prompt: prompt,
...(messages && { messages: messages }),
type: type, type: type,
stream: false stream: false
}) })

View File

@ -18,7 +18,7 @@
showControls showControls
} from '$lib/stores'; } from '$lib/stores';
import { blobToFile, findWordIndices } from '$lib/utils'; import { blobToFile, createMessagesList, findWordIndices } from '$lib/utils';
import { transcribeAudio } from '$lib/apis/audio'; import { transcribeAudio } from '$lib/apis/audio';
import { uploadFile } from '$lib/apis/files'; import { uploadFile } from '$lib/apis/files';
import { getTools } from '$lib/apis/tools'; import { getTools } from '$lib/apis/tools';
@ -606,7 +606,10 @@
const res = await generateAutoCompletion( const res = await generateAutoCompletion(
localStorage.token, localStorage.token,
selectedModelIds.at(0), selectedModelIds.at(0),
text text,
history?.currentId
? createMessagesList(history, history.currentId)
: null
).catch((error) => { ).catch((error) => {
console.log(error); console.log(error);
toast.error(error); toast.error(error);

View File

@ -7,6 +7,7 @@ export const AIAutocompletion = Extension.create({
addOptions() { addOptions() {
return { return {
generateCompletion: () => Promise.resolve(''), generateCompletion: () => Promise.resolve(''),
debounceTime: 1000,
} }
}, },
@ -45,6 +46,9 @@ export const AIAutocompletion = Extension.create({
}, },
addProseMirrorPlugins() { addProseMirrorPlugins() {
let debounceTimer = null;
let loading = false;
return [ return [
new Plugin({ new Plugin({
key: new PluginKey('aiAutocompletion'), key: new PluginKey('aiAutocompletion'),
@ -61,6 +65,8 @@ export const AIAutocompletion = Extension.create({
if (event.key === 'Tab') { if (event.key === 'Tab') {
if (!node.attrs['data-suggestion']) { if (!node.attrs['data-suggestion']) {
// Generate completion // Generate completion
if (loading) return true
loading = true
const prompt = node.textContent const prompt = node.textContent
this.options.generateCompletion(prompt).then(suggestion => { this.options.generateCompletion(prompt).then(suggestion => {
if (suggestion && suggestion.trim() !== '') { if (suggestion && suggestion.trim() !== '') {
@ -72,6 +78,8 @@ export const AIAutocompletion = Extension.create({
})) }))
} }
// If suggestion is empty or null, do nothing // If suggestion is empty or null, do nothing
}).finally(() => {
loading = false
}) })
} else { } else {
// Accept suggestion // Accept suggestion
@ -87,7 +95,9 @@ export const AIAutocompletion = Extension.create({
) )
} }
return true return true
} else if (node.attrs['data-suggestion']) { } else {
if (node.attrs['data-suggestion']) {
// Reset suggestion on any other key press // Reset suggestion on any other key press
dispatch(state.tr.setNodeMarkup($head.before(), null, { dispatch(state.tr.setNodeMarkup($head.before(), null, {
...node.attrs, ...node.attrs,
@ -97,6 +107,41 @@ export const AIAutocompletion = Extension.create({
})) }))
} }
// Set up debounce for AI generation
if (this.options.debounceTime !== null) {
clearTimeout(debounceTimer)
// Capture current position
const currentPos = $head.before()
debounceTimer = setTimeout(() => {
const newState = view.state
const newNode = newState.doc.nodeAt(currentPos)
// Check if the node still exists and is still a paragraph
if (newNode && newNode.type.name === 'paragraph') {
const prompt = newNode.textContent
if (prompt.trim() !== ''){
if (loading) return true
loading = true
this.options.generateCompletion(prompt).then(suggestion => {
if (suggestion && suggestion.trim() !== '') {
view.dispatch(newState.tr.setNodeMarkup(currentPos, null, {
...newNode.attrs,
class: 'ai-autocompletion',
'data-prompt': prompt,
'data-suggestion': suggestion,
}))
}
}).finally(() => {
loading = false
})
}
}
}, this.options.debounceTime)
}
}
return false return false
}, },
}, },

View File

@ -20,6 +20,11 @@
return; return;
} }
if (modelInfo.id === '') {
toast.error('Error: Model ID cannot be empty. Please enter a valid ID to proceed.');
return;
}
if (modelInfo) { if (modelInfo) {
const res = await createNewModel(localStorage.token, { const res = await createNewModel(localStorage.token, {
...modelInfo, ...modelInfo,