feat: search query threshold

This commit is contained in:
Timothy J. Baek 2024-06-09 15:19:36 -07:00
parent 8b4867deb5
commit 8debb71197
3 changed files with 32 additions and 3 deletions

View File

@ -618,6 +618,11 @@ ADMIN_EMAIL = PersistentConfig(
) )
####################################
# TASKS
####################################
TASK_MODEL = PersistentConfig( TASK_MODEL = PersistentConfig(
"TASK_MODEL", "TASK_MODEL",
"task.model.default", "task.model.default",
@ -664,6 +669,15 @@ Question:
) )
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
),
)
#################################### ####################################
# WEBUI_SECRET_KEY # WEBUI_SECRET_KEY
#################################### ####################################

View File

@ -81,6 +81,7 @@ from config import (
TASK_MODEL_EXTERNAL, TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -144,6 +145,9 @@ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMP
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
) )
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.MODELS = {} app.state.MODELS = {}
@ -596,6 +600,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query") print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
)
model_id = form_data["model"] model_id = form_data["model"]
if model_id not in app.state.MODELS: if model_id not in app.state.MODELS:
raise HTTPException( raise HTTPException(

View File

@ -56,6 +56,7 @@
import Messages from '$lib/components/chat/Messages.svelte'; import Messages from '$lib/components/chat/Messages.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte'; import Navbar from '$lib/components/layout/Navbar.svelte';
import CallOverlay from './MessageInput/CallOverlay.svelte'; import CallOverlay from './MessageInput/CallOverlay.svelte';
import { error } from '@sveltejs/kit';
const i18n: Writable<i18nType> = getContext('i18n'); const i18n: Writable<i18nType> = getContext('i18n');
@ -506,7 +507,13 @@
messages = messages; messages = messages;
const prompt = history.messages[parentId].content; const prompt = history.messages[parentId].content;
let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt); let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch(
(error) => {
console.log(error);
return prompt;
}
);
if (!searchQuery) { if (!searchQuery) {
toast.warning($i18n.t('No search query generated')); toast.warning($i18n.t('No search query generated'));
responseMessage.status = { responseMessage.status = {
@ -516,8 +523,6 @@
description: 'No search query generated' description: 'No search query generated'
}; };
messages = messages; messages = messages;
searchQuery = prompt;
} }
responseMessage.status = { responseMessage.status = {