enh: retrieval query generation

Timothy Jaeryang Baek 2024-11-19 02:24:32 -08:00
parent 09c6e4b92f
commit dbb67a12ca
7 changed files with 217 additions and 138 deletions

View File

@@ -177,35 +177,34 @@ def merge_and_sort_query_results(
 def query_collection(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
 ) -> dict:
     results = []
-    query_embedding = embedding_function(query)
-    for collection_name in collection_names:
-        if collection_name:
-            try:
-                result = query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                )
-                if result is not None:
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
-        else:
-            pass
+    for query in queries:
+        query_embedding = embedding_function(query)
+        for collection_name in collection_names:
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass
     return merge_and_sort_query_results(results, k=k)


 def query_collection_with_hybrid_search(
     collection_names: list[str],
-    query: str,
+    queries: list[str],
     embedding_function,
     k: int,
     reranking_function,
@@ -215,15 +214,16 @@ def query_collection_with_hybrid_search(
     error = False
     for collection_name in collection_names:
         try:
-            result = query_doc_with_hybrid_search(
-                collection_name=collection_name,
-                query=query,
-                embedding_function=embedding_function,
-                k=k,
-                reranking_function=reranking_function,
-                r=r,
-            )
-            results.append(result)
+            for query in queries:
+                result = query_doc_with_hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    embedding_function=embedding_function,
+                    k=k,
+                    reranking_function=reranking_function,
+                    r=r,
+                )
+                results.append(result)
         except Exception as e:
             log.exception(
                 "Error when querying the collection with " f"hybrid_search: {e}"
@@ -309,15 +309,14 @@ def get_embedding_function(
 def get_rag_context(
     files,
-    messages,
+    queries,
     embedding_function,
     k,
     reranking_function,
     r,
     hybrid_search,
 ):
-    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
-    query = get_last_user_message(messages)
+    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")

     extracted_collections = []
     relevant_contexts = []
@@ -359,7 +358,7 @@ def get_rag_context(
             try:
                 context = query_collection_with_hybrid_search(
                     collection_names=collection_names,
-                    query=query,
+                    queries=queries,
                     embedding_function=embedding_function,
                     k=k,
                     reranking_function=reranking_function,
@@ -374,7 +373,7 @@ def get_rag_context(
         if (not hybrid_search) or (context is None):
             context = query_collection(
                 collection_names=collection_names,
-                query=query,
+                queries=queries,
                 embedding_function=embedding_function,
                 k=k,
             )
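The hunks above change query_collection and query_collection_with_hybrid_search to take a list of generated queries instead of a single string. A minimal sketch of the resulting fan-out shape; the embedding function and query_doc stand-ins below are placeholders, not code from this commit:

def embed(text: str) -> list[float]:
    # stand-in for the embedding function returned by get_embedding_function()
    return [float(len(text)), float(text.count(" "))]

def fake_query_doc(collection_name: str, k: int, query_embedding: list[float]) -> list[dict]:
    # stand-in for query_doc(); returns pseudo-hits for demonstration only
    return [{"collection": collection_name, "score": sum(query_embedding) - i} for i in range(k)]

def query_collection_sketch(collection_names: list[str], queries: list[str], k: int) -> list[dict]:
    results = []
    for query in queries:  # one embedding per generated query
        query_embedding = embed(query)
        for collection_name in collection_names:
            results.extend(fake_query_doc(collection_name, k, query_embedding))
    return results  # the real code pools these via merge_and_sort_query_results

print(len(query_collection_sketch(["docs-1", "docs-2"], ["query1", "query2"], k=3)))  # 12 pooled hits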

View File

@@ -941,19 +941,49 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
     os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
 )

-ENABLE_SEARCH_QUERY = PersistentConfig(
-    "ENABLE_SEARCH_QUERY",
-    "task.search.enable",
-    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
+ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_SEARCH_QUERY_GENERATION",
+    "task.query.search.enable",
+    os.environ.get("ENABLE_SEARCH_QUERY_GENERATION", "True").lower() == "true",
+)
+
+ENABLE_RETRIEVAL_QUERY_GENERATION = PersistentConfig(
+    "ENABLE_RETRIEVAL_QUERY_GENERATION",
+    "task.query.retrieval.enable",
+    os.environ.get("ENABLE_RETRIEVAL_QUERY_GENERATION", "True").lower() == "true",
 )

-SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
-    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
-    "task.search.prompt_template",
-    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
+QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.query.prompt_template",
+    os.environ.get("QUERY_GENERATION_PROMPT_TEMPLATE", ""),
 )

+DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
+Based on the chat history, determine whether a search is necessary, and if so, generate 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
+
+### Guidelines:
+- Respond exclusively with a JSON object.
+- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
+- If no search query is necessary, output should be: { "queries": [] }
+- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
+- Be concise, focusing strictly on composing search queries with no additional commentary or text.
+- When in doubt, prefer to suggest a search for comprehensiveness.
+- Today's date is: {{CURRENT_DATE}}
+
+### Output:
+JSON format: {
+  "queries": ["query1", "query2"]
+}
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>
+"""
+
 TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
     "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -1127,27 +1157,6 @@ RAG_TEXT_SPLITTER = PersistentConfig(
 )

-ENABLE_RAG_QUERY_GENERATION = PersistentConfig(
-    "ENABLE_RAG_QUERY_GENERATION",
-    "rag.query_generation.enable",
-    os.environ.get("ENABLE_RAG_QUERY_GENERATION", "False").lower() == "true",
-)
-
-DEFAULT_RAG_QUERY_GENERATION_TEMPLATE = """Given the user's message and interaction history, decide if a file search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt.
-
-User Message:
-{{prompt:end:4000}}
-Interaction History:
-{{MESSAGES:END:6}}
-Search Query:"""
-
-RAG_QUERY_GENERATION_TEMPLATE = PersistentConfig(
-    "RAG_QUERY_GENERATION_TEMPLATE",
-    "rag.query_generation.template",
-    os.environ.get("RAG_QUERY_GENERATION_TEMPLATE", ""),
-)
-
 TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")

 TIKTOKEN_ENCODING_NAME = PersistentConfig(
     "TIKTOKEN_ENCODING_NAME",

View File

@@ -78,11 +78,13 @@ from open_webui.config import (
     ENV,
     FRONTEND_BUILD_DIR,
     OAUTH_PROVIDERS,
-    ENABLE_SEARCH_QUERY,
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     STATIC_DIR,
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
+    ENABLE_SEARCH_QUERY_GENERATION,
+    ENABLE_RETRIEVAL_QUERY_GENERATION,
+    QUERY_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     TITLE_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -122,7 +124,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.task import (
     moa_response_generation_template,
     tags_generation_template,
-    search_query_generation_template,
+    query_generation_template,
     emoji_generation_template,
     title_generation_template,
     tools_function_calling_generation_template,
@@ -206,10 +208,9 @@ app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE

-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
-app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
-)
+app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
+app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
+app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE

 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -492,14 +493,41 @@ async def chat_completion_tools_handler(
     return body, {"contexts": contexts, "citations": citations}


-async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
+async def chat_completion_files_handler(
+    body: dict, user: UserModel
+) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []

+    try:
+        queries_response = await generate_queries(
+            {
+                "model": body["model"],
+                "messages": body["messages"],
+                "type": "retrieval",
+            },
+            user,
+        )
+        queries_response = queries_response["choices"][0]["message"]["content"]
+
+        try:
+            queries_response = json.loads(queries_response)
+        except Exception as e:
+            queries_response = {"queries": []}
+
+        queries = queries_response.get("queries", [])
+    except Exception as e:
+        queries = []
+
+    if len(queries) == 0:
+        queries = [get_last_user_message(body["messages"])]
+
+    print(f"{queries=}")
+
     if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
-            messages=body["messages"],
+            queries=queries,
             embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
             k=retrieval_app.state.config.TOP_K,
             reranking_function=retrieval_app.state.sentence_transformer_rf,
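chat_completion_files_handler now asks the task model for retrieval queries and falls back to the last user message when parsing fails or no queries come back. A self-contained sketch of that fallback logic; the raw responses below are made-up examples, not output captured from this commit:

import json

def extract_queries(raw_response: str, last_user_message: str) -> list[str]:
    try:
        queries = json.loads(raw_response).get("queries", [])
    except Exception:
        queries = []
    if len(queries) == 0:
        queries = [last_user_message]  # fall back to the user's own message as the single query
    return queries

print(extract_queries('{"queries": ["open webui rag", "retrieval query generation"]}', "how does RAG work?"))
print(extract_queries("not json at all", "how does RAG work?"))  # -> ["how does RAG work?"]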
@@ -643,7 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             log.exception(e)

         try:
-            body, flags = await chat_completion_files_handler(body)
+            body, flags = await chat_completion_files_handler(body, user)
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -1579,8 +1607,9 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
@@ -1591,8 +1620,9 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     TAGS_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_TAGS_GENERATION: bool
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
-    ENABLE_SEARCH_QUERY: bool
+    ENABLE_SEARCH_QUERY_GENERATION: bool
+    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
+    QUERY_GENERATION_PROMPT_TEMPLATE: str
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@@ -1607,11 +1637,16 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         form_data.TAGS_GENERATION_PROMPT_TEMPLATE
     )
     app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
-
-    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
-        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
+        form_data.ENABLE_SEARCH_QUERY_GENERATION
+    )
+    app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
+        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
+    )
+    app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
     )
-    app.state.config.ENABLE_SEARCH_QUERY = form_data.ENABLE_SEARCH_QUERY
+
     app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
         form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     )
@@ -1622,8 +1657,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION,
-        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
-        "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
+        "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
@@ -1799,14 +1835,22 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
     return await generate_chat_completions(form_data=payload, user=user)


-@app.post("/api/task/query/completions")
-async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
-    print("generate_search_query")
-    if not app.state.config.ENABLE_SEARCH_QUERY:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"Search query generation is disabled",
-        )
+@app.post("/api/task/queries/completions")
+async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_queries")
+
+    type = form_data.get("type")
+    if type == "web_search":
+        if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Search query generation is disabled",
+            )
+    elif type == "retrieval":
+        if not app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Query generation is disabled",
+            )

     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
@@ -1830,20 +1874,12 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
     model = models[task_model_id]

-    if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
-        template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "":
+        template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
     else:
-        template = """Given the user's message and interaction history, decide if a web search is necessary. You must be concise and exclusively provide a search query if one is necessary. Refrain from verbose responses or any additional commentary. Prefer suggesting a search if uncertain to provide comprehensive or updated information. If a search isn't needed at all, respond with an empty string. Default to a search query when in doubt. Today's date is {{CURRENT_DATE}}.
-
-User Message:
-{{prompt:end:4000}}
-Interaction History:
-{{MESSAGES:END:6}}
-Search Query:"""
+        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

-    content = search_query_generation_template(
+    content = query_generation_template(
         template, form_data["messages"], {"name": user.name}
     )
@@ -1851,13 +1887,6 @@ Search Query:"""
         "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        **(
-            {"max_tokens": 30}
-            if models[task_model_id]["owned_by"] == "ollama"
-            else {
-                "max_completion_tokens": 30,
-            }
-        ),
         "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data},
     }
     log.debug(payload)
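The endpoint moves from /api/task/query/completions to /api/task/queries/completions and now accepts a type field ("web_search" or "retrieval"). A rough sketch of calling it from a script; the host, token, and model id are placeholders, and the response content is expected to be the {"queries": [...]} JSON object described by the default template:

import json
import requests  # assumes the requests package is available

resp = requests.post(
    "http://localhost:8080/api/task/queries/completions",
    headers={"Authorization": "Bearer <token>", "Content-Type": "application/json"},
    json={
        "model": "llama3.1",  # placeholder model id
        "messages": [{"role": "user", "content": "What changed in our pricing docs?"}],
        "type": "retrieval",  # "web_search" is gated by ENABLE_SEARCH_QUERY_GENERATION instead
    },
)
content = resp.json()["choices"][0]["message"]["content"]
print(json.loads(content).get("queries", []))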

View File

@@ -163,7 +163,7 @@ def emoji_generation_template(
     return template


-def search_query_generation_template(
+def query_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:
     prompt = get_last_user_message(messages)
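Only the rename from search_query_generation_template to query_generation_template is shown here; the helper's body is not part of this diff. The sketch below merely illustrates the placeholder convention the default template relies on ({{CURRENT_DATE}} and {{MESSAGES:END:6}}), with deliberately simplified substitution logic:

from datetime import date

def fill_template_sketch(template: str, messages: list[dict]) -> str:
    # keep only the last six messages, mirroring the END:6 hint in the placeholder
    last_six = "\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages[-6:])
    return template.replace("{{CURRENT_DATE}}", date.today().isoformat()).replace(
        "{{MESSAGES:END:6}}", last_six
    )

template = "Today's date is: {{CURRENT_DATE}}\n<chat_history>\n{{MESSAGES:END:6}}\n</chat_history>"
print(fill_template_sketch(template, [{"role": "user", "content": "find the Q3 report"}]))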

View File

@@ -348,15 +348,16 @@ export const generateEmoji = async (
 	return null;
 };

-export const generateSearchQuery = async (
+export const generateQueries = async (
 	token: string = '',
 	model: string,
 	messages: object[],
-	prompt: string
+	prompt: string,
+	type?: string = 'web_search'
 ) => {
 	let error = null;

-	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -366,7 +367,8 @@ export const generateSearchQuery = async (
 		body: JSON.stringify({
 			model: model,
 			messages: messages,
-			prompt: prompt
+			prompt: prompt,
+			type: type
 		})
 	})
 		.then(async (res) => {
@@ -385,7 +387,40 @@ export const generateSearchQuery = async (
 		throw error;
 	}

-	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array
+			if (parsed && parsed.queries) {
+				return Array.isArray(parsed.queries) ? parsed.queries : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
 };


 export const generateMoACompletion = async (

View File

@@ -26,8 +26,9 @@
 		TITLE_GENERATION_PROMPT_TEMPLATE: '',
 		TAGS_GENERATION_PROMPT_TEMPLATE: '',
 		ENABLE_TAGS_GENERATION: true,
-		ENABLE_SEARCH_QUERY: true,
-		SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
+		ENABLE_SEARCH_QUERY_GENERATION: true,
+		ENABLE_RETRIEVAL_QUERY_GENERATION: true,
+		QUERY_GENERATION_PROMPT_TEMPLATE: ''
 	};

 	let promptSuggestions = [];
@@ -164,31 +165,35 @@
 			<hr class=" dark:border-gray-850 my-3" />

+			<div class="my-3 flex w-full items-center justify-between">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('Enable Retrieval Query Generation')}
+				</div>
+
+				<Switch bind:state={taskConfig.ENABLE_RETRIEVAL_QUERY_GENERATION} />
+			</div>
+
 			<div class="my-3 flex w-full items-center justify-between">
 				<div class=" self-center text-xs font-medium">
 					{$i18n.t('Enable Web Search Query Generation')}
 				</div>

-				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY} />
+				<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY_GENERATION} />
 			</div>

-			{#if taskConfig.ENABLE_SEARCH_QUERY}
-				<div class="">
-					<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
+			<div class="">
+				<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Query Generation Prompt')}</div>

-					<Tooltip
-						content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
-						placement="top-start"
-					>
-						<Textarea
-							bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
-							placeholder={$i18n.t(
-								'Leave empty to use the default prompt, or enter a custom prompt'
-							)}
-						/>
-					</Tooltip>
-				</div>
-			{/if}
+				<Tooltip
+					content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+					placement="top-start"
+				>
+					<Textarea
+						bind:value={taskConfig.QUERY_GENERATION_PROMPT_TEMPLATE}
+						placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+					/>
+				</Tooltip>
+			</div>
 		</div>

 		<hr class=" dark:border-gray-850 my-3" />

View File

@@ -66,7 +66,7 @@
 	import {
 		chatCompleted,
 		generateTitle,
-		generateSearchQuery,
+		generateQueries,
 		chatAction,
 		generateMoACompletion,
 		generateTags
@@ -2047,17 +2047,17 @@
 		history.messages[responseMessageId] = responseMessage;

 		const prompt = userMessage.content;
-		let searchQuery = await generateSearchQuery(
+		let queries = await generateQueries(
 			localStorage.token,
 			model,
 			messages.filter((message) => message?.content?.trim()),
 			prompt
 		).catch((error) => {
 			console.log(error);
-			return prompt;
+			return [];
 		});

-		if (!searchQuery || searchQuery == '') {
+		if (queries.length === 0) {
 			responseMessage.statusHistory.push({
 				done: true,
 				error: true,
@@ -2068,6 +2068,8 @@
 			return;
 		}

+		const searchQuery = queries[0];
+
 		responseMessage.statusHistory.push({
 			done: false,
 			action: 'web_search',