feat: autocomplete backend endpoint

This commit is contained in:
Timothy Jaeryang Baek 2024-11-28 23:53:52 -08:00
parent 28ce102a79
commit 0e8e9820d0
4 changed files with 139 additions and 0 deletions

View File

@ -999,6 +999,47 @@ Strictly return in JSON format:
"""
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE",
"task.autocomplete.prompt_template",
os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task:
You are an **autocompletion system**. Your sole task is to generate concise, logical continuations for text provided within the `<text>` tag. Additional guidance on the purpose, tone, or format will be included in a `<context>` tag.
Only output a continuation. If you are unsure how to proceed, output nothing.
### **Instructions**
1. Analyze the `<text>` to understand its structure, context, and flow.
2. Refer to the `<context>` for any specific purpose or format (e.g., search queries, general, etc.).
3. Complete the text concisely and meaningfully without repeating or altering the original.
4. Do not introduce unrelated ideas or elaborate unnecessarily.
### **Output Rules**
- Respond *only* with the continuationno preamble or explanation.
- Ensure the continuation directly connects to the given text and adheres to the context.
- If unsure about completing, provide no output.
### **Examples**
**Example 1**
<context>General</context>
<text>The sun was dipping below the horizon, painting the sky in shades of pink and orange as the cool breeze began to set in.</text>
**Output**: A sense of calm spread through the air, and the first stars started to shimmer faintly above.
**Example 2**
<context>Search</context>
<text>How to prepare for a job interview</text>
**Output**: effectively, including researching the company and practicing common questions.
**Example 3**
<context>Search</context>
<text>Best destinations for hiking in</text>
**Output**: Europe, such as the Alps or the Scottish Highlands.
"""
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",

View File

@ -113,5 +113,6 @@ class TASKS(str, Enum):
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
AUTOCOMPLETION_GENERATION = "autocompletion_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

View File

@ -89,6 +89,8 @@ from open_webui.config import (
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
@ -127,6 +129,7 @@ from open_webui.utils.task import (
rag_template,
title_generation_template,
query_generation_template,
autocomplete_generation_template,
tags_generation_template,
emoji_generation_template,
moa_response_generation_template,
@ -215,6 +218,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
@ -1982,6 +1989,73 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/auto/completions")
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
context = form_data.get("context")
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
app.state.config.TASK_MODEL,
app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating autocompletion using model {task_model_id} for user {user.email}"
)
if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
content = autocomplete_generation_template(
template, form_data["messages"], context, {"name": user.name}
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
"task": str(TASKS.AUTOCOMPLETION_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Handle pipeline filters
try:
payload = filter_pipeline(payload, user, models)
except Exception as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):

View File

@ -212,6 +212,29 @@ def emoji_generation_template(
return template
def autocomplete_generation_template(
template: str,
messages: list[dict],
context: Optional[str] = None,
user: Optional[dict] = None,
) -> str:
prompt = get_last_user_message(messages)
template = template.replace("{{CONTEXT}}", context if context else "")
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def query_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str: