enh: retrieval query generation

This commit is contained in:
Timothy Jaeryang Baek
2024-11-19 02:24:32 -08:00
parent 09c6e4b92f
commit dbb67a12ca
7 changed files with 217 additions and 138 deletions

View File

@@ -348,15 +348,16 @@ export const generateEmoji = async (
return null;
};
export const generateSearchQuery = async (
export const generateQueries = async (
token: string = '',
model: string,
messages: object[],
prompt: string
prompt: string,
type?: string = 'web_search'
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
const res = await fetch(`${WEBUI_BASE_URL}/api/task/queries/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
@@ -366,7 +367,8 @@ export const generateSearchQuery = async (
body: JSON.stringify({
model: model,
messages: messages,
prompt: prompt
prompt: prompt,
type: type
})
})
.then(async (res) => {
@@ -385,7 +387,40 @@ export const generateSearchQuery = async (
throw error;
}
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
try {
// Step 1: Safely extract the response string
const response = res?.choices[0]?.message?.content ?? '';
// Step 2: Attempt to fix common JSON format issues like single quotes
const sanitizedResponse = response.replace(/['`]/g, '"'); // Convert single quotes to double quotes for valid JSON
// Step 3: Find the relevant JSON block within the response
const jsonStartIndex = sanitizedResponse.indexOf('{');
const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
// Step 5: Parse the JSON block
const parsed = JSON.parse(jsonResponse);
// Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array
if (parsed && parsed.queries) {
return Array.isArray(parsed.queries) ? parsed.queries : [];
} else {
return [];
}
}
// If no valid JSON block found, return an empty array
return [];
} catch (e) {
// Catch and safely return empty array on any parsing errors
console.error('Failed to parse response: ', e);
return [];
}
};
export const generateMoACompletion = async (

View File

@@ -26,8 +26,9 @@
TITLE_GENERATION_PROMPT_TEMPLATE: '',
TAGS_GENERATION_PROMPT_TEMPLATE: '',
ENABLE_TAGS_GENERATION: true,
ENABLE_SEARCH_QUERY: true,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: ''
ENABLE_SEARCH_QUERY_GENERATION: true,
ENABLE_RETRIEVAL_QUERY_GENERATION: true,
QUERY_GENERATION_PROMPT_TEMPLATE: ''
};
let promptSuggestions = [];
@@ -164,31 +165,35 @@
<hr class=" dark:border-gray-850 my-3" />
<div class="my-3 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Enable Retrieval Query Generation')}
</div>
<Switch bind:state={taskConfig.ENABLE_RETRIEVAL_QUERY_GENERATION} />
</div>
<div class="my-3 flex w-full items-center justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Enable Web Search Query Generation')}
</div>
<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY} />
<Switch bind:state={taskConfig.ENABLE_SEARCH_QUERY_GENERATION} />
</div>
{#if taskConfig.ENABLE_SEARCH_QUERY}
<div class="">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Search Query Generation Prompt')}</div>
<div class="">
<div class=" mb-2.5 text-xs font-medium">{$i18n.t('Query Generation Prompt')}</div>
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t(
'Leave empty to use the default prompt, or enter a custom prompt'
)}
/>
</Tooltip>
</div>
{/if}
<Tooltip
content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
placement="top-start"
>
<Textarea
bind:value={taskConfig.QUERY_GENERATION_PROMPT_TEMPLATE}
placeholder={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
/>
</Tooltip>
</div>
</div>
<hr class=" dark:border-gray-850 my-3" />

View File

@@ -66,7 +66,7 @@
import {
chatCompleted,
generateTitle,
generateSearchQuery,
generateQueries,
chatAction,
generateMoACompletion,
generateTags
@@ -2047,17 +2047,17 @@
history.messages[responseMessageId] = responseMessage;
const prompt = userMessage.content;
let searchQuery = await generateSearchQuery(
let queries = await generateQueries(
localStorage.token,
model,
messages.filter((message) => message?.content?.trim()),
prompt
).catch((error) => {
console.log(error);
return prompt;
return [];
});
if (!searchQuery || searchQuery == '') {
if (queries.length === 0) {
responseMessage.statusHistory.push({
done: true,
error: true,
@@ -2068,6 +2068,8 @@
return;
}
const searchQuery = queries[0];
responseMessage.statusHistory.push({
done: false,
action: 'web_search',