feat: add in-message progress indicator for web search

This commit is contained in:
Jun Siang Cheah 2024-05-12 15:21:03 +08:00
parent d45804d7f4
commit 3baeda7edc
4 changed files with 138 additions and 72 deletions

View File

@ -519,9 +519,7 @@ export const runWebSearch = async (
query: string,
collection_name?: string
): Promise<SearchDocument | undefined> => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
return await fetch(`${RAG_API_BASE_URL}/websearch`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@ -529,7 +527,7 @@ export const runWebSearch = async (
},
body: JSON.stringify({
query,
collection_name
collection_name: collection_name ?? ''
})
})
.then(async (res) => {
@ -538,15 +536,8 @@ export const runWebSearch = async (
})
.catch((err) => {
console.log(err);
error = err.detail;
return undefined;
});
if (error) {
throw error;
}
return res;
};
export interface SearchDocument {

View File

@ -369,6 +369,62 @@
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-ol:p-0 prose-li:-mb-4 whitespace-pre-line"
>
<div>
{#if message.progress}
<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
<div>
<button
class="h-16 flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
type="button"
>
<div class="p-2.5 bg-red-400 text-white rounded-lg">
<svg
class=" w-6 h-6 translate-y-[0.5px]"
fill="currentColor"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_qM83 {
animation: spinner_8HQG 1.05s infinite;
}
.spinner_oXPr {
animation-delay: 0.1s;
}
.spinner_ZTLf {
animation-delay: 0.2s;
}
@keyframes spinner_8HQG {
0%,
57.14% {
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
transform: translate(0);
}
28.57% {
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
transform: translateY(-6px);
}
100% {
transform: translate(0);
}
}
</style><circle class="spinner_qM83" cx="4" cy="12" r="2.5" /><circle
class="spinner_qM83 spinner_oXPr"
cx="12"
cy="12"
r="2.5"
/><circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="2.5" /></svg
>
</div>
<div class="flex flex-col justify-center -space-y-0.5">
<div class=" dark:text-gray-100 text-sm font-medium line-clamp-2 text-wrap">
{message.progress}
</div>
</div>
</button>
</div>
</div>
{/if}
{#if edit === true}
<div class=" w-full">
<textarea

View File

@ -31,7 +31,11 @@
updateChatById
} from '$lib/apis/chats';
import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
import {
generateOpenAIChatCompletion,
generateSearchQuery,
generateTitle
} from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte';
@ -286,36 +290,7 @@
}
if (useWebSearch) {
// TODO: Toasts are temporary indicators for web search
toast.info($i18n.t('Generating search query'));
const searchQuery = await generateChatSearchQuery(prompt);
if (searchQuery) {
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
const searchDocUuid = uuidv4();
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
if (searchDocument) {
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ""
});
// Find message in messages and update it
const messageIndex = messages.findIndex((message) => message.id === parentId);
if (messageIndex !== -1) {
messages[messageIndex] = parentMessage;
}
} else {
toast.warning($i18n.t('No search results found'));
}
} else {
toast.warning($i18n.t('No search query generated'));
}
await runWebSearchForPrompt(parentId, responseMessageId, prompt);
}
if (model?.external) {
@ -332,6 +307,41 @@
await chats.set(await getChatList(localStorage.token));
};
const runWebSearchForPrompt = async (parentId: string, responseId: string, prompt: string) => {
const responseMessage = history.messages[responseId];
responseMessage.progress = $i18n.t('Generating search query');
messages = messages;
const searchQuery = await generateChatSearchQuery(prompt);
if (!searchQuery) {
toast.warning($i18n.t('No search query generated'));
responseMessage.progress = undefined;
messages = messages;
return;
}
responseMessage.progress = $i18n.t("Searching the web for '{{searchQuery}}'", { searchQuery });
messages = messages;
const searchDocument = await runWebSearch(localStorage.token, searchQuery);
if (!searchDocument) {
toast.warning($i18n.t('No search results found'));
responseMessage.progress = undefined;
messages = messages;
return;
}
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument!.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ''
});
responseMessage.progress = undefined;
messages = messages;
};
const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
model = model.id;
const responseMessage = history.messages[responseMessageId];

View File

@ -291,36 +291,7 @@
}
if (useWebSearch) {
// TODO: Toasts are temporary indicators for web search
toast.info($i18n.t('Generating search query'));
const searchQuery = await generateChatSearchQuery(prompt);
if (searchQuery) {
toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
const searchDocUuid = uuidv4();
const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
if (searchDocument) {
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ""
});
// Find message in messages and update it
const messageIndex = messages.findIndex((message) => message.id === parentId);
if (messageIndex !== -1) {
messages[messageIndex] = parentMessage;
}
} else {
toast.warning($i18n.t('No search results found'));
}
} else {
toast.warning($i18n.t('No search query generated'));
}
await runWebSearchForPrompt(parentId, responseMessageId, prompt);
}
if (model?.external) {
@ -337,6 +308,44 @@
await chats.set(await getChatList(localStorage.token));
};
const runWebSearchForPrompt = async (parentId: string, responseId: string, prompt: string) => {
const responseMessage = history.messages[responseId];
responseMessage.progress = $i18n.t('Generating search query');
messages = messages;
const searchQuery = await generateChatSearchQuery(prompt);
if (!searchQuery) {
toast.warning($i18n.t('No search query generated'));
responseMessage.progress = undefined;
messages = messages;
return;
}
responseMessage.progress = $i18n.t("Searching the web for '{{searchQuery}}'", { searchQuery });
messages = messages;
const searchDocument = await runWebSearch(
localStorage.token,
searchQuery,
);
if (!searchDocument) {
toast.warning($i18n.t('No search results found'));
responseMessage.progress = undefined;
messages = messages;
return;
}
const parentMessage = history.messages[parentId];
if (!parentMessage.files) {
parentMessage.files = [];
}
parentMessage.files.push({
collection_name: searchDocument!.collection_name,
name: searchQuery,
type: 'doc',
upload_status: true,
error: ''
});
responseMessage.progress = undefined;
messages = messages;
};
const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
model = model.id;
const responseMessage = history.messages[responseMessageId];