refac: web search settings

This commit is contained in:
Timothy J. Baek
2024-06-01 19:40:48 -07:00
parent ea6b8984ab
commit 912a704fdc
5 changed files with 300 additions and 86 deletions

View File

@@ -97,9 +97,11 @@ from config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
@@ -145,9 +147,12 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -351,23 +356,25 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enable": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
}
},
},
}
@@ -384,9 +391,11 @@ class YoutubeLoaderConfig(BaseModel):
class WebSearchConfig(BaseModel):
enable: bool
engine: Optional[str] = None
searxng_query_url: Optional[str] = None
google_pse_api_key: Optional[str] = None
google_pse_engine_id: Optional[str] = None
brave_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
@@ -394,11 +403,16 @@ class WebSearchConfig(BaseModel):
concurrent_requests: Optional[int] = None
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
@app.post("/config/update")
@@ -409,35 +423,36 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
if form_data.youtube is not None:
app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
if form_data.web is not None:
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web.web_loader_ssl_verification
)
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enable
app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
app.state.config.GOOGLE_PSE_ENGINE_ID = (
form_data.web.search.google_pse_engine_id
)
app.state.config.BRAVE_SEARCH_API_KEY = (
form_data.web.search.brave_search_api_key
)
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
return {
"status": True,
@@ -446,11 +461,26 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enable": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
@@ -690,7 +720,9 @@ def resolve_hostname(hostname):
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
try:
web_results = search_web(form_data.query)
web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
)
except Exception as e:
log.exception(e)
raise HTTPException(

View File

@@ -538,7 +538,7 @@ class RerankCompressor(BaseDocumentCompressor):
return final_results
def search_web(query: str) -> list[SearchResult]:
def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
@@ -552,15 +552,34 @@ def search_web(query: str) -> list[SearchResult]:
"""
# TODO: add playwright to search the web
if SEARXNG_QUERY_URL:
return search_searxng(SEARXNG_QUERY_URL, query)
elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
elif BRAVE_SEARCH_API_KEY:
return search_brave(BRAVE_SEARCH_API_KEY, query)
elif SERPSTACK_API_KEY:
return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
elif SERPER_API_KEY:
return search_serper(SERPER_API_KEY, query)
if engine == "searxng":
if SEARXNG_QUERY_URL:
return search_searxng(SEARXNG_QUERY_URL, query)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
elif engine == "google_pse":
if GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
else:
raise Exception(
"No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
)
elif engine == "brave":
if BRAVE_SEARCH_API_KEY:
return search_brave(BRAVE_SEARCH_API_KEY, query)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if SERPSTACK_API_KEY:
return search_serpstack(
SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
)
else:
raise Exception("No SERPSTACK_API_KEY found in environment variables")
elif engine == "serper":
if SERPER_API_KEY:
return search_serper(SERPER_API_KEY, query)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
else:
raise Exception("No search engine API key found in environment variables")

View File

@@ -773,6 +773,11 @@ ENABLE_RAG_WEB_SEARCH = PersistentConfig(
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
"RAG_WEB_SEARCH_ENGINE",
"rag.web.search.engine",
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",