diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0a76626c1..a941627cf 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -999,6 +999,47 @@ Strictly return in JSON format: """ +AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", + "task.autocomplete.prompt_template", + os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task: +You are an **autocompletion system**. Your sole task is to generate concise, logical continuations for text provided within the `` tag. Additional guidance on the purpose, tone, or format will be included in a `` tag. + +Only output a continuation. If you are unsure how to proceed, output nothing. + +### **Instructions** +1. Analyze the `` to understand its structure, context, and flow. +2. Refer to the `` for any specific purpose or format (e.g., search queries, general, etc.). +3. Complete the text concisely and meaningfully without repeating or altering the original. +4. Do not introduce unrelated ideas or elaborate unnecessarily. + +### **Output Rules** +- Respond *only* with the continuation—no preamble or explanation. +- Ensure the continuation directly connects to the given text and adheres to the context. +- If unsure about completing, provide no output. + +### **Examples** + +**Example 1** +General +The sun was dipping below the horizon, painting the sky in shades of pink and orange as the cool breeze began to set in. +**Output**: A sense of calm spread through the air, and the first stars started to shimmer faintly above. + +**Example 2** +Search +How to prepare for a job interview +**Output**: effectively, including researching the company and practicing common questions. + +**Example 3** +Search +Best destinations for hiking in +**Output**: Europe, such as the Alps or the Scottish Highlands. 
+""" + + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", "task.tools.prompt_template", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 9c7d6f9e9..d25353f0e 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -113,5 +113,6 @@ class TASKS(str, Enum): TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" + AUTOCOMPLETION_GENERATION = "autocompletion_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d8c246dc3..311bf3968 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -89,6 +89,8 @@ from open_webui.config import ( DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, WEBHOOK_URL, WEBUI_AUTH, @@ -127,6 +129,7 @@ from open_webui.utils.task import ( rag_template, title_generation_template, query_generation_template, + autocomplete_generation_template, tags_generation_template, emoji_generation_template, moa_response_generation_template, @@ -215,6 +218,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) @@ -1982,6 +1989,73 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): return await 
@app.post("/api/task/auto/completions")
async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
    """Task endpoint: generate an inline text autocompletion.

    Renders the autocomplete prompt template with the request's context and
    messages, then dispatches a non-streaming chat completion on the
    (possibly overridden) task model.
    """
    context = form_data.get("context")

    # Build an id -> model lookup from the currently available models.
    model_list = await get_all_models()
    models = {model["id"]: model for model in model_list}

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Honor a user-configured custom task model when one is set.
    task_model_id = get_task_model_id(
        model_id,
        app.state.config.TASK_MODEL,
        app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # Prefer the admin-configured template; fall back to the built-in default
    # when the configured value is empty/whitespace-only.
    custom_template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    template = (
        custom_template
        if custom_template.strip() != ""
        else DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    )

    content = autocomplete_generation_template(
        template, form_data["messages"], context, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.AUTOCOMPLETION_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline filters; a filter may raise with (status_code, detail)
    # packed into the exception args — surface those as JSON responses.
    try:
        payload = filter_pipeline(payload, user, models)
    except Exception as e:
        if len(e.args) > 1:
            return JSONResponse(
                status_code=e.args[0],
                content={"detail": e.args[1]},
            )
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    # filter_pipeline may inject a top-level chat_id; strip it before dispatch.
    payload.pop("chat_id", None)

    return await generate_chat_completions(form_data=payload, user=user)
def autocomplete_generation_template(
    template: str,
    messages: list[dict],
    context: Optional[str] = None,
    user: Optional[dict] = None,
) -> str:
    """Render the autocomplete-generation prompt template.

    Substitutes the ``{{CONTEXT}}`` placeholder, the prompt/messages
    template variables, and — when *user* is provided — the user name and
    location fields, returning the fully rendered prompt string.
    """
    # Fill in the completion context first; absent context becomes "".
    rendered = template.replace("{{CONTEXT}}", context if context else "")

    # Inject the last user message and the full message history.
    last_prompt = get_last_user_message(messages)
    rendered = replace_prompt_variable(rendered, last_prompt)
    rendered = replace_messages_variable(rendered, messages)

    # Finally apply user-level variables, only if a user dict was given.
    user_kwargs = (
        {"user_name": user.get("name"), "user_location": user.get("location")}
        if user
        else {}
    )
    return prompt_template(rendered, **user_kwargs)