Fix: adjusted to handle both default and individual rag settings

This commit is contained in:
Maytown 2025-05-14 17:33:11 +02:00
parent 1ae3873c55
commit ba54452ab1
2 changed files with 107 additions and 41 deletions

View File

@ -190,6 +190,8 @@ class ProcessUrlForm(CollectionNameForm):
class SearchForm(BaseModel): class SearchForm(BaseModel):
query: str query: str
class CollectionForm(BaseModel):
knowledge_id: Optional[str] = None
@router.get("/") @router.get("/")
async def get_status(request: Request): async def get_status(request: Request):
@ -206,13 +208,15 @@ async def get_status(request: Request):
@router.post("/embedding") @router.post("/embedding")
async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)): async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
""" """
Retrieve the embedding configuration. Retrieve the embedding configuration.
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings. If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
Otherwise, return the embedding configuration stored in the database. Otherwise, return the embedding configuration stored in the database.
""" """
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database # Return the embedding configuration from the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
@ -249,13 +253,15 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
@router.post("/reranking") @router.post("/reranking")
async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)): async def get_reranking_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
""" """
Retrieve the reranking configuration. Retrieve the reranking configuration.
If DEFAULT_RAG_SETTINGS is True, return the default reranking settings. If DEFAULT_RAG_SETTINGS is True, return the default reranking settings.
Otherwise, return the reranking configuration stored in the database. Otherwise, return the reranking configuration stored in the database.
""" """
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the reranking configuration from the database # Return the reranking configuration from the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
@ -287,7 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel):
embedding_engine: str embedding_engine: str
embedding_model: str embedding_model: str
embedding_batch_size: Optional[int] = 1 embedding_batch_size: Optional[int] = 1
collection_name: Optional[str] = None knowledge_id: Optional[str] = None
@router.post("/embedding/update") @router.post("/embedding/update")
@ -300,7 +306,7 @@ async def update_embedding_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base. Otherwise, update the RAG configuration in the database for the user's knowledge base.
""" """
try: try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name) knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database # Update the RAG configuration in the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
@ -312,13 +318,12 @@ async def update_embedding_config(
rag_config["embedding_model"] = form_data.embedding_model rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size rag_config["embedding_batch_size"] = form_data.embedding_batch_size
if form_data.openai_config is not None:
rag_config["openai_config"] = { rag_config["openai_config"] = {
"url": form_data.openai_config.url, "url": form_data.openai_config.url,
"key": form_data.openai_config.key, "key": form_data.openai_config.key,
} }
if form_data.ollama_config is not None:
rag_config["ollama_config"] = { rag_config["ollama_config"] = {
"url": form_data.ollama_config.url, "url": form_data.ollama_config.url,
"key": form_data.ollama_config.key, "key": form_data.ollama_config.key,
@ -348,8 +353,8 @@ async def update_embedding_config(
) )
# Save the updated configuration to the database # Save the updated configuration to the database
Knowledges.update_knowledge_data_by_id( Knowledges.update_rag_config_by_id(
id=form_data.collection_name, data={"rag_config": rag_config} id=form_data.knowledge_id, rag_config=rag_config
) )
return { return {
@ -428,7 +433,7 @@ async def update_embedding_config(
class RerankingModelUpdateForm(BaseModel): class RerankingModelUpdateForm(BaseModel):
reranking_model: str reranking_model: str
collection_name: Optional[str] knowledge_id: Optional[str] = None
@router.post("/reranking/update") @router.post("/reranking/update")
async def update_reranking_config( async def update_reranking_config(
@ -440,16 +445,19 @@ async def update_reranking_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base. Otherwise, update the RAG configuration in the database for the user's knowledge base.
""" """
try: try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database # Update the RAG configuration in the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
log.info( log.info(
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}" f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
) )
rag_config["reranking_model"] = form_data.reranking_model rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
Knowledges.update_knowledge_data_by_id( Knowledges.update_rag_config_by_id(
id=knowledge_base.id, data={"rag_config": rag_config} id=form_data.knowledge_id, rag_config=rag_config
) )
try: try:
if not request.app.state.rf.get(rag_config["reranking_model"]): if not request.app.state.rf.get(rag_config["reranking_model"]):
@ -500,13 +508,15 @@ async def update_reranking_config(
@router.post("/config") @router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)): async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
""" """
Retrieve the full RAG configuration. Retrieve the full RAG configuration.
If DEFAULT_RAG_SETTINGS is True, return the default settings. If DEFAULT_RAG_SETTINGS is True, return the default settings.
Otherwise, return the RAG configuration stored in the database. Otherwise, return the RAG configuration stored in the database.
""" """
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database # Return the RAG configuration from the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
@ -764,18 +774,26 @@ class ConfigForm(BaseModel):
# Web search settings # Web search settings
web: Optional[WebConfig] = None web: Optional[WebConfig] = None
# knowledge base ID
knowledge_id: Optional[str] = None
class ConfigFormWrapper(BaseModel):
form_data: ConfigForm
@router.post("/config/update") @router.post("/config/update")
async def update_rag_config( async def update_rag_config(
request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user) request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
): ):
""" """
Update the RAG configuration. Update the RAG configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration. If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base. Otherwise, update the RAG configuration in the database for the user's knowledge base.
""" """
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name) form_data = wrapper.form_data
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database # Update the RAG configuration in the database
rag_config = knowledge_base.rag_config rag_config = knowledge_base.rag_config
@ -783,14 +801,15 @@ async def update_rag_config(
# Update only the provided fields in the rag_config # Update only the provided fields in the rag_config
for field, value in form_data.model_dump(exclude_unset=True).items(): for field, value in form_data.model_dump(exclude_unset=True).items():
if field == "web" and value is not None: if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)} rag_config["web"] = {**rag_config.get("web", {}), **value}
else: else:
rag_config[field] = value rag_config[field] = value
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True): if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
if rag_config.get("reranking_model"):
request.app.state.rf[rag_config["reranking_model"]] = None request.app.state.rf[rag_config["reranking_model"]] = None
Knowledges.update_knowledge_data_by_id( Knowledges.update_rag_config_by_id(
id=knowledge_base.id, data={"rag_config": rag_config} id=knowledge_base.id, rag_config=rag_config
) )
return rag_config return rag_config
@ -1090,6 +1109,7 @@ async def update_rag_config(
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
} }

View File

@ -1,6 +1,6 @@
import { RETRIEVAL_API_BASE_URL } from '$lib/constants'; import { RETRIEVAL_API_BASE_URL } from '$lib/constants';
export const getRAGConfig = async (token: string, collectionForm?: CollectionNameForm) => { export const getRAGConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null; let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, { const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, {
@ -52,21 +52,68 @@ type YoutubeConfigForm = {
proxy_url: string; proxy_url: string;
}; };
type WebConfigForm = {
ENABLE_WEB_SEARCH?: boolean;
WEB_SEARCH_ENGINE?: string;
WEB_SEARCH_TRUST_ENV?: boolean;
WEB_SEARCH_RESULT_COUNT?: number;
WEB_SEARCH_CONCURRENT_REQUESTS?: number;
WEB_SEARCH_DOMAIN_FILTER_LIST?: string[];
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL?: boolean;
SEARXNG_QUERY_URL?: string;
YACY_QUERY_URL?: string;
YACY_USERNAME?: string;
YACY_PASSWORD?: string;
GOOGLE_PSE_API_KEY?: string;
GOOGLE_PSE_ENGINE_ID?: string;
BRAVE_SEARCH_API_KEY?: string;
KAGI_SEARCH_API_KEY?: string;
MOJEEK_SEARCH_API_KEY?: string;
BOCHA_SEARCH_API_KEY?: string;
SERPSTACK_API_KEY?: string;
SERPSTACK_HTTPS?: boolean;
SERPER_API_KEY?: string;
SERPLY_API_KEY?: string;
TAVILY_API_KEY?: string;
SEARCHAPI_API_KEY?: string;
SEARCHAPI_ENGINE?: string;
SERPAPI_API_KEY?: string;
SERPAPI_ENGINE?: string;
JINA_API_KEY?: string;
BING_SEARCH_V7_ENDPOINT?: string;
BING_SEARCH_V7_SUBSCRIPTION_KEY?: string;
EXA_API_KEY?: string;
PERPLEXITY_API_KEY?: string;
SOUGOU_API_SID?: string;
SOUGOU_API_SK?: string;
WEB_LOADER_ENGINE?: string;
ENABLE_WEB_LOADER_SSL_VERIFICATION?: boolean;
PLAYWRIGHT_WS_URL?: string;
PLAYWRIGHT_TIMEOUT?: number;
FIRECRAWL_API_KEY?: string;
FIRECRAWL_API_BASE_URL?: string;
TAVILY_EXTRACT_DEPTH?: string;
EXTERNAL_WEB_SEARCH_URL?: string;
EXTERNAL_WEB_SEARCH_API_KEY?: string;
EXTERNAL_WEB_LOADER_URL?: string;
EXTERNAL_WEB_LOADER_API_KEY?: string;
YOUTUBE_LOADER_LANGUAGE?: string[];
YOUTUBE_LOADER_PROXY_URL?: string;
YOUTUBE_LOADER_TRANSLATION?: string;
};
type RAGConfigForm = { type RAGConfigForm = {
PDF_EXTRACT_IMAGES?: boolean; PDF_EXTRACT_IMAGES?: boolean;
ENABLE_GOOGLE_DRIVE_INTEGRATION?: boolean; ENABLE_GOOGLE_DRIVE_INTEGRATION?: boolean;
ENABLE_ONEDRIVE_INTEGRATION?: boolean; ENABLE_ONEDRIVE_INTEGRATION?: boolean;
chunk?: ChunkConfigForm; chunk?: ChunkConfigForm;
content_extraction?: ContentExtractConfigForm; content_extraction?: ContentExtractConfigForm;
web_loader_ssl_verification?: boolean; web?: WebConfigForm;
youtube?: YoutubeConfigForm; youtube?: YoutubeConfigForm;
knowledge_id?: string;
}; };
type CollectionNameForm = {
collection_name: string;
};
export const updateRAGConfig = async (token: string, payload: RAGConfigForm, collectionForm?: CollectionNameForm) => { export const updateRAGConfig = async (token: string, form_data: RAGConfigForm) => {
let error = null; let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config/update`, { const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config/update`, {
@ -76,8 +123,7 @@ export const updateRAGConfig = async (token: string, payload: RAGConfigForm, col
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify({
...payload, form_data
...(collectionForm ? { collectionForm: collectionForm } : {})
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -160,7 +206,7 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings
return res; return res;
}; };
export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionNameForm) => { export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null; let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, { const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, {
@ -200,7 +246,7 @@ type EmbeddingModelUpdateForm = {
embedding_engine: string; embedding_engine: string;
embedding_model: string; embedding_model: string;
embedding_batch_size?: number; embedding_batch_size?: number;
collection_name?: string; knowledge_id?: string;
}; };
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
@ -233,7 +279,7 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
return res; return res;
}; };
export const getRerankingConfig = async (token: string, collectionForm?: CollectionNameForm) => { export const getRerankingConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null; let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/reranking`, { const res = await fetch(`${RETRIEVAL_API_BASE_URL}/reranking`, {
@ -265,7 +311,7 @@ export const getRerankingConfig = async (token: string, collectionForm?: Collect
type RerankingModelUpdateForm = { type RerankingModelUpdateForm = {
reranking_model: string; reranking_model: string;
collection_name?: string; knowledge_id?: string;
}; };