refac: web search

This commit is contained in:
Timothy J. Baek 2024-06-01 19:52:12 -07:00
parent 912a704fdc
commit 999d2bc21b
5 changed files with 119 additions and 92 deletions

View File

@ -59,9 +59,16 @@ from apps.rag.utils import (
query_doc_with_hybrid_search, query_doc_with_hybrid_search,
query_collection, query_collection,
query_collection_with_hybrid_search, query_collection_with_hybrid_search,
search_web,
) )
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
calculate_sha256_string, calculate_sha256_string,
@ -716,19 +723,78 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run a web search with the named engine and return its results.

    The engine is selected by ``engine``; its settings are read from
    ``app.state.config``:

    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY (SERPSTACK_HTTPS is passed as https_enabled)
    - "serper": SERPER_API_KEY

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Returns:
        The value returned by the selected engine's search helper.

    Raises:
        Exception: If the selected engine's settings are empty, or if
            ``engine`` is not one of the names listed above.
    """
    # TODO: add playwright to search the web
    if engine not in ("searxng", "google_pse", "brave", "serpstack", "serper"):
        raise Exception("No search engine API key found in environment variables")

    settings = app.state.config

    if engine == "searxng":
        if not settings.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(settings.SEARXNG_QUERY_URL, query)

    if engine == "google_pse":
        # Both the API key and the engine id are required for Google PSE.
        if not (settings.GOOGLE_PSE_API_KEY and settings.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            settings.GOOGLE_PSE_API_KEY,
            settings.GOOGLE_PSE_ENGINE_ID,
            query,
        )

    if engine == "brave":
        if not settings.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(settings.BRAVE_SEARCH_API_KEY, query)

    if engine == "serpstack":
        if not settings.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            settings.SERPSTACK_API_KEY,
            query,
            https_enabled=settings.SERPSTACK_HTTPS,
        )

    # Only "serper" remains after the membership check above.
    if not settings.SERPER_API_KEY:
        raise Exception("No SERPER_API_KEY found in environment variables")
    return search_serper(settings.SERPER_API_KEY, query)
@app.post("/web/search") @app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
try: try:
web_results = search_web( web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR, detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
) )
try:
urls = [result.link for result in web_results] urls = [result.link for result in web_results]
loader = get_web_loader(urls) loader = get_web_loader(urls)
data = loader.load() data = loader.load()

View File

@ -20,12 +20,7 @@ from langchain.retrievers import (
from typing import Optional from typing import Optional
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
CHROMA_CLIENT, CHROMA_CLIENT,
@ -536,50 +531,3 @@ class RerankCompressor(BaseDocumentCompressor):
) )
final_results.append(doc) final_results.append(doc)
return final_results return final_results
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run a web search with the named engine and return its results.

    The engine is selected by ``engine``; its settings are read from
    module-level names:

    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY (SERPSTACK_HTTPS is passed as https_enabled)
    - "serper": SERPER_API_KEY

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Returns:
        The value returned by the selected engine's search helper.

    Raises:
        Exception: If the selected engine's settings are empty, or if
            ``engine`` is not one of the names listed above.
    """
    # TODO: add playwright to search the web
    if engine == "searxng":
        if not SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(SEARXNG_QUERY_URL, query)

    if engine == "google_pse":
        # Both the API key and the engine id are required for Google PSE.
        if not (GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)

    if engine == "brave":
        if not BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(BRAVE_SEARCH_API_KEY, query)

    if engine == "serpstack":
        if not SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
        )

    if engine == "serper":
        if not SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(SERPER_API_KEY, query)

    # Unrecognised engine name.
    raise Exception("No search engine API key found in environment variables")

View File

@ -82,5 +82,5 @@ class ERROR_MESSAGES(str, Enum):
) )
WEB_SEARCH_ERROR = ( WEB_SEARCH_ERROR = (
"Oops! Something went wrong while searching the web. Please try again later." lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
) )

View File

@ -518,8 +518,10 @@ export const runWebSearch = async (
token: string, token: string,
query: string, query: string,
collection_name?: string collection_name?: string
): Promise<SearchDocument | undefined> => { ): Promise<SearchDocument | null> => {
return await fetch(`${RAG_API_BASE_URL}/web/search`, { let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/web/search`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -536,8 +538,15 @@ export const runWebSearch = async (
}) })
.catch((err) => { .catch((err) => {
console.log(err); console.log(err);
return undefined; error = err.detail;
return null;
}); });
if (error) {
throw error;
}
return res;
}; };
export interface SearchDocument { export interface SearchDocument {

View File

@ -473,19 +473,14 @@
}; };
messages = messages; messages = messages;
const results = await runWebSearch(localStorage.token, searchQuery); const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => {
if (results === undefined) { console.log(error);
toast.warning($i18n.t('No search results found')); toast.error(error);
responseMessage.status = {
...responseMessage.status,
done: true,
error: true,
description: 'No search results found'
};
messages = messages;
return;
}
return null;
});
if (results) {
responseMessage.status = { responseMessage.status = {
...responseMessage.status, ...responseMessage.status,
done: true, done: true,
@ -505,6 +500,15 @@
}); });
messages = messages; messages = messages;
} else {
responseMessage.status = {
...responseMessage.status,
done: true,
error: true,
description: 'No search results found'
};
messages = messages;
}
}; };
const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => { const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {