From 8debb711971f5744a8e98bf4b9d33853a208316a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 9 Jun 2024 15:19:36 -0700 Subject: [PATCH] feat: search query threshold --- backend/config.py | 14 ++++++++++++++ backend/main.py | 10 ++++++++++ src/lib/components/chat/Chat.svelte | 11 ++++++++--- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/backend/config.py b/backend/config.py index 2e718ce8c..27c4c1277 100644 --- a/backend/config.py +++ b/backend/config.py @@ -618,6 +618,11 @@ ADMIN_EMAIL = PersistentConfig( ) +#################################### +# TASKS +#################################### + + TASK_MODEL = PersistentConfig( "TASK_MODEL", "task.model.default", @@ -664,6 +669,15 @@ Question: ) +SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + "task.search.prompt_length_threshold", + os.environ.get( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + 100, + ), +) + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/main.py b/backend/main.py index abd899614..75fce3a46 100644 --- a/backend/main.py +++ b/backend/main.py @@ -81,6 +81,7 @@ from config import ( TASK_MODEL_EXTERNAL, TITLE_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, AppConfig, ) from constants import ERROR_MESSAGES @@ -144,6 +145,9 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE ) +app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD +) app.state.MODELS = {} @@ -596,6 +600,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): print("generate_search_query") + if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)", + ) + model_id = form_data["model"] if model_id not in app.state.MODELS: raise HTTPException( diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index c47e8d3a3..29031aeaa 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -56,6 +56,7 @@ import Messages from '$lib/components/chat/Messages.svelte'; import Navbar from '$lib/components/layout/Navbar.svelte'; import CallOverlay from './MessageInput/CallOverlay.svelte'; + import { error } from '@sveltejs/kit'; const i18n: Writable = getContext('i18n'); @@ -506,7 +507,13 @@ messages = messages; const prompt = history.messages[parentId].content; - let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt); + let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( + (error) => { + console.log(error); + return prompt; + } + ); + if (!searchQuery) { toast.warning($i18n.t('No search query generated')); responseMessage.status = { @@ -516,8 +523,6 @@ description: 'No search query generated' }; messages = messages; - - searchQuery = prompt; } responseMessage.status = {