From 81a8ad27620f794f469238e29b5f199d6a5915f5 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 30 Nov 2024 18:30:59 -0800 Subject: [PATCH] refac: autocomplete settings --- backend/open_webui/config.py | 11 +++++ backend/open_webui/main.py | 42 +++++++++++++++-- .../admin/Settings/Interface.svelte | 45 ++++++++++++++++--- src/lib/components/chat/MessageInput.svelte | 3 +- 4 files changed, 89 insertions(+), 12 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 06a5a6811..15d209941 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1004,6 +1004,17 @@ Strictly return in JSON format: """ +ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig( + "ENABLE_AUTOCOMPLETE_GENERATION", + "task.autocomplete.enable", + os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true", +) + +AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig( + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", + "task.autocomplete.input_max_length", + int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")), +) AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 40724fd30..b25ebbda1 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -89,6 +89,8 @@ from open_webui.config import ( DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, + ENABLE_AUTOCOMPLETE_GENERATION, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -210,6 +212,11 @@ app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION +app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH +) + app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE @@ -1672,6 +1679,8 @@ async def get_task_config(user=Depends(get_verified_user)): "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -1685,6 +1694,8 @@ class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] TITLE_GENERATION_PROMPT_TEMPLATE: str + ENABLE_AUTOCOMPLETE_GENERATION: bool + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str ENABLE_TAGS_GENERATION: bool ENABLE_SEARCH_QUERY_GENERATION: bool @@ -1700,6 +1711,14 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + + app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( + form_data.ENABLE_AUTOCOMPLETE_GENERATION + ) + app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ) + app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( form_data.TAGS_GENERATION_PROMPT_TEMPLATE ) @@ -1722,6 +1741,8 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -1991,6 +2012,23 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): @app.post("/api/task/auto/completions") async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): + if not app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Autocompletion generation is disabled", + ) + + type = form_data.get("type") + prompt = form_data.get("prompt") + messages = form_data.get("messages") + + if app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: + if len(prompt) > app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Input prompt exceeds maximum length of {app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + ) + model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -2019,10 +2057,6 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use else: template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - type = form_data.get("type") - prompt = form_data.get("prompt") - messages = form_data.get("messages") - content = autocomplete_generation_template( template, prompt, messages, type, {"name": user.name} ) diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 2fee518ee..f0846a892 100644 --- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -24,6 +24,8 @@ TASK_MODEL: '', TASK_MODEL_EXTERNAL: '', TITLE_GENERATION_PROMPT_TEMPLATE: '', + ENABLE_AUTOCOMPLETE_GENERATION: true, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1, TAGS_GENERATION_PROMPT_TEMPLATE: '', ENABLE_TAGS_GENERATION: true, ENABLE_SEARCH_QUERY_GENERATION: true, @@ -138,11 +140,42 @@ -
+
- {$i18n.t('Enable Tags Generation')} + {$i18n.t('Autocomplete Generation')} +
+ + + + +
+ + {#if taskConfig.ENABLE_AUTOCOMPLETE_GENERATION} +
+
+ {$i18n.t('Autocomplete Generation Input Max Length')} +
+ + + + +
+ {/if} + +
+ +
+
+ {$i18n.t('Tags Generation')}
@@ -166,11 +199,11 @@
{/if} -
+
- {$i18n.t('Enable Retrieval Query Generation')} + {$i18n.t('Retrieval Query Generation')}
@@ -178,7 +211,7 @@
- {$i18n.t('Enable Web Search Query Generation')} + {$i18n.t('Web Search Query Generation')}
@@ -201,7 +234,7 @@
-
+
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index a8f61ea52..4df2b552a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -612,12 +612,11 @@ : null ).catch((error) => { console.log(error); - toast.error(error); + return null; }); console.log(res); - return res; }} on:keydown={async (e) => {