refac: web search

This commit is contained in:
Timothy Jaeryang Baek 2024-12-24 17:52:57 -07:00
parent a074991d3a
commit 6b25139d4f
5 changed files with 179 additions and 100 deletions

View File

@ -856,6 +856,7 @@ async def chat_completion(
"session_id": form_data.pop("session_id", None),
"tool_ids": form_data.get("tool_ids", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
}
form_data["metadata"] = metadata

View File

@ -1238,7 +1238,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
@router.post("/process/web/search")
def process_web_search(
async def process_web_search(
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):
try:
@ -1256,9 +1256,11 @@ def process_web_search(
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
log.debug(f"web_results: {web_results}")
try:
collection_name = form_data.collection_name
if collection_name == "":
if collection_name == "" or collection_name is None:
collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
:63
]
@ -1269,8 +1271,7 @@ def process_web_search(
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
docs = loader.aload()
docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
return {

View File

@ -29,6 +29,7 @@ from open_webui.routers.tasks import (
generate_title,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.utils.webhook import post_webhook
@ -333,6 +334,149 @@ async def chat_completion_tools_handler(
return body, {"sources": sources}
def _extract_search_queries(response: str) -> list:
    """Turn the query-generation reply into a list of candidate queries.

    The reply is expected to contain a JSON object shaped like
    ``{"queries": [...]}``, possibly surrounded by other text. Whenever such
    an object cannot be located or parsed, the untouched reply is returned as
    the single candidate query.
    """
    start = response.find("{")
    end = response.rfind("}")
    # Both braces must exist and be in order, otherwise there is nothing to
    # slice out. (Comparing ``rfind(...) + 1`` against -1 can never match.)
    if start == -1 or end < start:
        return [response]

    try:
        parsed = json.loads(response[start : end + 1])
    except ValueError:
        # json.JSONDecodeError is a ValueError. Fall back to the full reply
        # rather than the sliced (possibly empty) fragment.
        return [response]

    if not isinstance(parsed, dict):
        return [response]

    queries = parsed.get("queries", [])
    if isinstance(queries, str):
        # A bare string would otherwise be indexed character by character.
        return [queries]
    if isinstance(queries, list):
        return queries
    return [response]


async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
) -> dict:
    """Run a web search for the last user message and attach the results.

    Progress is reported through ``extra_params["__event_emitter__"]`` as
    ``status`` events with ``action == "web_search"``. On success an entry of
    type ``web_search_results`` is appended to ``form_data["files"]``.

    Args:
        request: Incoming request, forwarded to ``generate_queries`` and
            ``process_web_search``.
        form_data: Chat payload; ``messages`` and ``model`` are read and
            ``files`` may be extended in place.
        extra_params: Must contain the ``__event_emitter__`` coroutine.
        user: The requesting user, forwarded to the helpers above.

    Returns:
        ``form_data`` on every path, so the caller can always rebind its
        payload to the return value.
    """
    event_emitter = extra_params["__event_emitter__"]

    async def emit_status(description: str, done: bool, **extra) -> None:
        # Every event of this handler shares the same envelope; only the
        # description, completion flag and a few optional fields vary.
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": description,
                    **extra,
                    "done": done,
                },
            }
        )

    await emit_status("Generating search query", done=False)

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )
        queries = _extract_search_queries(res["choices"][0]["message"]["content"])
    except Exception as e:
        # Query generation is best-effort: search for the raw user message.
        log.exception(e)
        queries = [user_message]

    # Drop anything that cannot be searched for (None, non-strings, blanks).
    queries = [q for q in queries if isinstance(q, str) and q.strip()]

    if not queries:
        await emit_status("No search query generated", done=True)
        # The caller assigns the return value back to its payload, so this
        # path must hand the payload back rather than None.
        return form_data

    searchQuery = queries[0]
    await emit_status('Searching "{{searchQuery}}"', done=False, query=searchQuery)

    try:
        results = await process_web_search(
            request,
            SearchForm(**{"query": searchQuery}),
            user,
        )

        if results:
            await emit_status(
                "Searched {{count}} sites",
                done=True,
                query=searchQuery,
                urls=results["filenames"],
            )

            # "files" may be absent or explicitly None.
            files = form_data.get("files") or []
            files.append(
                {
                    "collection_name": results["collection_name"],
                    "name": searchQuery,
                    "type": "web_search_results",
                    "urls": results["filenames"],
                }
            )
            form_data["files"] = files
        else:
            await emit_status(
                "No search results found",
                done=True,
                query=searchQuery,
                error=True,
            )
    except Exception as e:
        log.exception(e)
        await emit_status(
            'Error searching "{{searchQuery}}"',
            done=True,
            query=searchQuery,
            error=True,
        )

    return form_data
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
@ -456,7 +600,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
knowledge_files = []
for item in model_knowledge:
print(item)
if item.get("collection_name"):
knowledge_files.append(
{
@ -481,6 +624,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
files.extend(knowledge_files)
form_data["files"] = files
features = form_data.pop("features", None)
if features:
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params

View File

@ -1419,11 +1419,8 @@
const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
scrollToBottom();
if (webSearchEnabled) {
await getWebSearchResults(model.id, parentId, responseMessageId);
}
await sendPromptSocket(model, responseMessageId, _chatId);
if (chatEventEmitter) clearInterval(chatEventEmitter);
} else {
toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
@ -1533,8 +1530,12 @@
: undefined
},
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
features: {
web_search: webSearchEnabled
},
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId,
@ -1751,94 +1752,6 @@
}
};
// Generates a search query for the user message at `parentId`, runs a web search
// for it and attaches the results to the response message `responseMessageId`.
// Each step is appended to the response message's `statusHistory`.
// TODO: move this to the backend
const getWebSearchResults = async (
	model: string,
	parentId: string,
	responseMessageId: string
) => {
	const responseMessage = history.messages[responseMessageId];
	const userMessage = history.messages[parentId];
	const messages = createMessagesList(history.currentId);

	responseMessage.statusHistory = [
		{
			done: false,
			action: 'web_search',
			description: $i18n.t('Generating search query')
		}
	];
	history.messages[responseMessageId] = responseMessage;

	const prompt = userMessage.content;
	// If query generation fails, fall back to searching for the raw prompt.
	let queries = await generateQueries(
		localStorage.token,
		model,
		messages.filter((message) => message?.content?.trim()),
		prompt
	).catch((error) => {
		console.log(error);
		return [prompt];
	});

	if (queries.length === 0) {
		responseMessage.statusHistory.push({
			done: true,
			error: true,
			action: 'web_search',
			description: $i18n.t('No search query generated')
		});
		history.messages[responseMessageId] = responseMessage;
		return;
	}

	const searchQuery = queries[0];

	responseMessage.statusHistory.push({
		done: false,
		action: 'web_search',
		description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery })
	});
	history.messages[responseMessageId] = responseMessage;

	const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => {
		console.log(error);
		toast.error(error);
		return null;
	});

	if (results) {
		responseMessage.statusHistory.push({
			done: true,
			action: 'web_search',
			description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }),
			query: searchQuery,
			urls: results.filenames
		});

		// Only initialise `files` when it is missing. `===` binds tighter than `??`,
		// so an unparenthesised `files ?? undefined === undefined` is always truthy
		// and would discard any files already attached.
		if ((responseMessage.files ?? undefined) === undefined) {
			responseMessage.files = [];
		}

		responseMessage.files.push({
			collection_name: results.collection_name,
			name: searchQuery,
			type: 'web_search_results',
			urls: results.filenames
		});
		history.messages[responseMessageId] = responseMessage;
	} else {
		responseMessage.statusHistory.push({
			done: true,
			error: true,
			action: 'web_search',
			// Translated like every other status description in this function.
			description: $i18n.t('No search results found')
		});
		history.messages[responseMessageId] = responseMessage;
	}
};
const initChatHandler = async () => {
if (!$temporaryChatEnabled) {
chat = await createNewChat(localStorage.token, {

View File

@ -535,7 +535,14 @@
? 'shimmer'
: ''} text-base line-clamp-1 text-wrap"
>
{status?.description}
<!-- $i18n.t('Searched {{count}} sites') -->
{#if status?.description.includes('{{count}}')}
{$i18n.t(status?.description, {
count: status?.urls.length
})}
{:else}
{$i18n.t(status?.description)}
{/if}
</div>
</div>
</WebSearchResults>
@ -558,7 +565,14 @@
? 'shimmer'
: ''} text-gray-500 dark:text-gray-500 text-base line-clamp-1 text-wrap"
>
{status?.description}
<!-- $i18n.t(`Searching "{{searchQuery}}"`) -->
{#if status?.description.includes('{{searchQuery}}')}
{$i18n.t(status?.description, {
searchQuery: status?.query
})}
{:else}
{$i18n.t(status?.description)}
{/if}
</div>
</div>
{/if}