From ba54452ab194f7c281e07152eb9eaaa27b6c6333 Mon Sep 17 00:00:00 2001 From: Maytown Date: Wed, 14 May 2025 17:33:11 +0200 Subject: [PATCH] Fix: adjusted to handle both default and individual rag settings --- backend/open_webui/routers/retrieval.py | 76 ++++++++++++++++--------- src/lib/apis/retrieval/index.ts | 72 ++++++++++++++++++----- 2 files changed, 107 insertions(+), 41 deletions(-) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 78d207d2a..49f739ce5 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -190,6 +190,8 @@ class ProcessUrlForm(CollectionNameForm): class SearchForm(BaseModel): query: str +class CollectionForm(BaseModel): + knowledge_id: Optional[str] = None @router.get("/") async def get_status(request: Request): @@ -206,13 +208,15 @@ async def get_status(request: Request): @router.post("/embedding") -async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)): +async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)): """ Retrieve the embedding configuration. If DEFAULT_RAG_SETTINGS is True, return the default embedding settings. Otherwise, return the embedding configuration stored in the database. """ - knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the embedding configuration from the database rag_config = knowledge_base.rag_config @@ -249,13 +253,15 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF @router.post("/reranking") -async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)): +async def get_reranking_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)): """ Retrieve the reranking configuration. If DEFAULT_RAG_SETTINGS is True, return the default reranking settings. Otherwise, return the reranking configuration stored in the database. """ - knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the reranking configuration from the database rag_config = knowledge_base.rag_config @@ -287,7 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 - collection_name: Optional[str] = None + knowledge_id: Optional[str] = None @router.post("/embedding/update") @@ -300,7 +306,7 @@ async def update_embedding_config( Otherwise, update the RAG configuration in the database for the user's knowledge base. """ try: - knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database rag_config = knowledge_base.rag_config @@ -312,14 +318,13 @@ async def update_embedding_config( rag_config["embedding_model"] = form_data.embedding_model rag_config["embedding_batch_size"] = form_data.embedding_batch_size - if form_data.openai_config is not None: - rag_config["openai_config"] = { + + rag_config["openai_config"] = { "url": form_data.openai_config.url, "key": form_data.openai_config.key, } - if form_data.ollama_config is not None: - rag_config["ollama_config"] = { + rag_config["ollama_config"] = { "url": form_data.ollama_config.url, "key": form_data.ollama_config.key, } @@ -348,8 +353,8 @@ async def update_embedding_config( ) # Save the updated configuration to the database - Knowledges.update_knowledge_data_by_id( - id=form_data.collection_name, data={"rag_config": rag_config} + Knowledges.update_rag_config_by_id( + id=form_data.knowledge_id, rag_config=rag_config ) return { @@ -428,7 +433,7 @@ async def update_embedding_config( class RerankingModelUpdateForm(BaseModel): reranking_model: str - collection_name: Optional[str] + knowledge_id: Optional[str] = None @router.post("/reranking/update") async def update_reranking_config( @@ -440,16 +445,19 @@ async def update_reranking_config( Otherwise, update the RAG configuration in the database for the user's knowledge base. """ try: - knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) + + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database + rag_config = knowledge_base.rag_config log.info( - f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" + f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}" ) - rag_config["reranking_model"] = form_data.reranking_model - Knowledges.update_knowledge_data_by_id( - id=knowledge_base.id, data={"rag_config": rag_config} + rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None + Knowledges.update_rag_config_by_id( + id=form_data.knowledge_id, rag_config=rag_config ) try: if not request.app.state.rf.get(rag_config["reranking_model"]): @@ -500,13 +508,15 @@ async def update_reranking_config( @router.post("/config") -async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)): +async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)): """ Retrieve the full RAG configuration. If DEFAULT_RAG_SETTINGS is True, return the default settings. Otherwise, return the RAG configuration stored in the database. """ - knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) + + knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id) + if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Return the RAG configuration from the database rag_config = knowledge_base.rag_config @@ -764,18 +774,26 @@ class ConfigForm(BaseModel): # Web search settings web: Optional[WebConfig] = None + # knowledge base ID + knowledge_id: Optional[str] = None + + +class ConfigFormWrapper(BaseModel): + form_data: ConfigForm + @router.post("/config/update") async def update_rag_config( - request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user) + request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user) ): """ Update the RAG configuration. If DEFAULT_RAG_SETTINGS is True, update the global configuration. Otherwise, update the RAG configuration in the database for the user's knowledge base. """ - knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) - + form_data = wrapper.form_data + + knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id) if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): # Update the RAG configuration in the database rag_config = knowledge_base.rag_config @@ -783,14 +801,15 @@ async def update_rag_config( # Update only the provided fields in the rag_config for field, value in form_data.model_dump(exclude_unset=True).items(): if field == "web" and value is not None: - rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)} + rag_config["web"] = {**rag_config.get("web", {}), **value} else: rag_config[field] = value if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True): - request.app.state.rf[rag_config["reranking_model"]] = None - - Knowledges.update_knowledge_data_by_id( - id=knowledge_base.id, data={"rag_config": rag_config} + if rag_config.get("reranking_model"): + request.app.state.rf[rag_config["reranking_model"]] = None + + Knowledges.update_rag_config_by_id( + id=knowledge_base.id, rag_config=rag_config ) return rag_config @@ -1090,6 +1109,7 @@ async def update_rag_config( "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, + "DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS } diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 4f7724f2e..90bac89b3 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -1,6 +1,6 @@ import { RETRIEVAL_API_BASE_URL } from '$lib/constants'; -export const getRAGConfig = async (token: string, collectionForm?: CollectionNameForm) => { +export const getRAGConfig = async (token: string, collectionForm?: CollectionForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, { @@ -52,21 +52,68 @@ type YoutubeConfigForm = { proxy_url: string; }; +type WebConfigForm = { + ENABLE_WEB_SEARCH?: boolean; + WEB_SEARCH_ENGINE?: string; + WEB_SEARCH_TRUST_ENV?: boolean; + WEB_SEARCH_RESULT_COUNT?: number; + WEB_SEARCH_CONCURRENT_REQUESTS?: number; + WEB_SEARCH_DOMAIN_FILTER_LIST?: string[]; + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL?: boolean; + SEARXNG_QUERY_URL?: string; + YACY_QUERY_URL?: string; + YACY_USERNAME?: string; + YACY_PASSWORD?: string; + GOOGLE_PSE_API_KEY?: string; + GOOGLE_PSE_ENGINE_ID?: string; + BRAVE_SEARCH_API_KEY?: string; + KAGI_SEARCH_API_KEY?: string; + MOJEEK_SEARCH_API_KEY?: string; + BOCHA_SEARCH_API_KEY?: string; + SERPSTACK_API_KEY?: string; + SERPSTACK_HTTPS?: boolean; + SERPER_API_KEY?: string; + SERPLY_API_KEY?: string; + TAVILY_API_KEY?: string; + SEARCHAPI_API_KEY?: string; + SEARCHAPI_ENGINE?: string; + SERPAPI_API_KEY?: string; + SERPAPI_ENGINE?: string; + JINA_API_KEY?: string; + BING_SEARCH_V7_ENDPOINT?: string; + BING_SEARCH_V7_SUBSCRIPTION_KEY?: string; + EXA_API_KEY?: string; + PERPLEXITY_API_KEY?: string; + SOUGOU_API_SID?: string; + SOUGOU_API_SK?: string; + WEB_LOADER_ENGINE?: string; + ENABLE_WEB_LOADER_SSL_VERIFICATION?: boolean; + PLAYWRIGHT_WS_URL?: string; + PLAYWRIGHT_TIMEOUT?: number; + FIRECRAWL_API_KEY?: string; + FIRECRAWL_API_BASE_URL?: string; + TAVILY_EXTRACT_DEPTH?: string; + EXTERNAL_WEB_SEARCH_URL?: string; + EXTERNAL_WEB_SEARCH_API_KEY?: string; + EXTERNAL_WEB_LOADER_URL?: string; + EXTERNAL_WEB_LOADER_API_KEY?: string; + YOUTUBE_LOADER_LANGUAGE?: string[]; + YOUTUBE_LOADER_PROXY_URL?: string; + YOUTUBE_LOADER_TRANSLATION?: string; +}; type RAGConfigForm = { PDF_EXTRACT_IMAGES?: boolean; ENABLE_GOOGLE_DRIVE_INTEGRATION?: boolean; ENABLE_ONEDRIVE_INTEGRATION?: boolean; chunk?: ChunkConfigForm; content_extraction?: ContentExtractConfigForm; - web_loader_ssl_verification?: boolean; + web?: WebConfigForm; youtube?: YoutubeConfigForm; + knowledge_id?: string; }; -type CollectionNameForm = { - collection_name: string; -}; -export const updateRAGConfig = async (token: string, payload: RAGConfigForm, collectionForm?: CollectionNameForm) => { +export const updateRAGConfig = async (token: string, form_data: RAGConfigForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config/update`, { @@ -76,9 +123,8 @@ export const updateRAGConfig = async (token: string, payload: RAGConfigForm, col Authorization: `Bearer ${token}` }, body: JSON.stringify({ - ...payload, - ...(collectionForm ? { collectionForm: collectionForm } : {}) - }) + form_data + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -160,7 +206,7 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings return res; }; -export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionNameForm) => { +export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, { @@ -200,7 +246,7 @@ type EmbeddingModelUpdateForm = { embedding_engine: string; embedding_model: string; embedding_batch_size?: number; - collection_name?: string; + knowledge_id?: string; }; export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { @@ -233,7 +279,7 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod return res; }; -export const getRerankingConfig = async (token: string, collectionForm?: CollectionNameForm) => { +export const getRerankingConfig = async (token: string, collectionForm?: CollectionForm) => { let error = null; const res = await fetch(`${RETRIEVAL_API_BASE_URL}/reranking`, { @@ -265,7 +311,7 @@ export const getRerankingConfig = async (token: string, collectionForm?: Collect type RerankingModelUpdateForm = { reranking_model: string; - collection_name?: string; + knowledge_id?: string; };