Fix: adjusted to handle both default and individual RAG settings

This commit is contained in:
Maytown 2025-05-14 17:33:11 +02:00
parent 1ae3873c55
commit ba54452ab1
2 changed files with 107 additions and 41 deletions

View File

@ -190,6 +190,8 @@ class ProcessUrlForm(CollectionNameForm):
class SearchForm(BaseModel):
query: str
class CollectionForm(BaseModel):
knowledge_id: Optional[str] = None
@router.get("/")
async def get_status(request: Request):
@ -206,13 +208,15 @@ async def get_status(request: Request):
@router.post("/embedding")
async def get_embedding_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
async def get_embedding_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
"""
Retrieve the embedding configuration.
If DEFAULT_RAG_SETTINGS is True, return the default embedding settings.
Otherwise, return the embedding configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database
rag_config = knowledge_base.rag_config
@ -249,13 +253,15 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
@router.post("/reranking")
async def get_reranking_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_verified_user)):
async def get_reranking_config(request: Request, collectionForm: Optional[CollectionForm], user=Depends(get_verified_user)):
"""
Retrieve the reranking configuration.
If DEFAULT_RAG_SETTINGS is True, return the default reranking settings.
Otherwise, return the reranking configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the reranking configuration from the database
rag_config = knowledge_base.rag_config
@ -287,7 +293,7 @@ class EmbeddingModelUpdateForm(BaseModel):
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
collection_name: Optional[str] = None
knowledge_id: Optional[str] = None
@router.post("/embedding/update")
@ -300,7 +306,7 @@ async def update_embedding_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
@ -312,14 +318,13 @@ async def update_embedding_config(
rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
if form_data.openai_config is not None:
rag_config["openai_config"] = {
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
if form_data.ollama_config is not None:
rag_config["ollama_config"] = {
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
@ -348,8 +353,8 @@ async def update_embedding_config(
)
# Save the updated configuration to the database
Knowledges.update_knowledge_data_by_id(
id=form_data.collection_name, data={"rag_config": rag_config}
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
return {
@ -428,7 +433,7 @@ async def update_embedding_config(
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
collection_name: Optional[str]
knowledge_id: Optional[str] = None
@router.post("/reranking/update")
async def update_reranking_config(
@ -440,16 +445,19 @@ async def update_reranking_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
log.info(
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
f"Updating reranking model: {rag_config.get('reranking_model')} to {form_data.reranking_model}"
)
rag_config["reranking_model"] = form_data.reranking_model
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
rag_config["reranking_model"] = form_data.reranking_model if form_data.reranking_model else None
Knowledges.update_rag_config_by_id(
id=form_data.knowledge_id, rag_config=rag_config
)
try:
if not request.app.state.rf.get(rag_config["reranking_model"]):
@ -500,13 +508,15 @@ async def update_reranking_config(
@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionNameForm, user=Depends(get_admin_user)):
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_admin_user)):
"""
Retrieve the full RAG configuration.
If DEFAULT_RAG_SETTINGS is True, return the default settings.
Otherwise, return the RAG configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database
rag_config = knowledge_base.rag_config
@ -764,18 +774,26 @@ class ConfigForm(BaseModel):
# Web search settings
web: Optional[WebConfig] = None
# knowledge base ID
knowledge_id: Optional[str] = None
class ConfigFormWrapper(BaseModel):
form_data: ConfigForm
@router.post("/config/update")
async def update_rag_config(
request: Request, form_data: ConfigForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
request: Request, wrapper: ConfigFormWrapper, user=Depends(get_admin_user)
):
"""
Update the RAG configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
form_data = wrapper.form_data
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
@ -783,14 +801,15 @@ async def update_rag_config(
# Update only the provided fields in the rag_config
for field, value in form_data.model_dump(exclude_unset=True).items():
if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
rag_config["web"] = {**rag_config.get("web", {}), **value}
else:
rag_config[field] = value
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
request.app.state.rf[rag_config["reranking_model"]] = None
if rag_config.get("reranking_model"):
request.app.state.rf[rag_config["reranking_model"]] = None
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
Knowledges.update_rag_config_by_id(
id=knowledge_base.id, rag_config=rag_config
)
return rag_config
@ -1090,6 +1109,7 @@ async def update_rag_config(
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"DEFAULT_RAG_SETTINGS": request.app.state.config.DEFAULT_RAG_SETTINGS
}

View File

@ -1,6 +1,6 @@
import { RETRIEVAL_API_BASE_URL } from '$lib/constants';
export const getRAGConfig = async (token: string, collectionForm?: CollectionNameForm) => {
export const getRAGConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, {
@ -52,21 +52,68 @@ type YoutubeConfigForm = {
proxy_url: string;
};
/**
 * Shape of the optional `web` section of {@link RAGConfigForm}.
 *
 * Every field is optional, so a caller can send only the keys it wants to set.
 * Keys are upper-case and are passed through `JSON.stringify` unchanged by
 * `updateRAGConfig`.
 */
type WebConfigForm = {
// General web search options.
ENABLE_WEB_SEARCH?: boolean;
WEB_SEARCH_ENGINE?: string;
WEB_SEARCH_TRUST_ENV?: boolean;
WEB_SEARCH_RESULT_COUNT?: number;
WEB_SEARCH_CONCURRENT_REQUESTS?: number;
WEB_SEARCH_DOMAIN_FILTER_LIST?: string[];
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL?: boolean;
// Per-provider URLs and credentials; each key is prefixed with the provider
// name it appears to belong to (NOTE(review): confirm against the backend).
SEARXNG_QUERY_URL?: string;
YACY_QUERY_URL?: string;
YACY_USERNAME?: string;
YACY_PASSWORD?: string;
GOOGLE_PSE_API_KEY?: string;
GOOGLE_PSE_ENGINE_ID?: string;
BRAVE_SEARCH_API_KEY?: string;
KAGI_SEARCH_API_KEY?: string;
MOJEEK_SEARCH_API_KEY?: string;
BOCHA_SEARCH_API_KEY?: string;
SERPSTACK_API_KEY?: string;
SERPSTACK_HTTPS?: boolean;
SERPER_API_KEY?: string;
SERPLY_API_KEY?: string;
TAVILY_API_KEY?: string;
SEARCHAPI_API_KEY?: string;
SEARCHAPI_ENGINE?: string;
SERPAPI_API_KEY?: string;
SERPAPI_ENGINE?: string;
JINA_API_KEY?: string;
BING_SEARCH_V7_ENDPOINT?: string;
BING_SEARCH_V7_SUBSCRIPTION_KEY?: string;
EXA_API_KEY?: string;
PERPLEXITY_API_KEY?: string;
SOUGOU_API_SID?: string;
SOUGOU_API_SK?: string;
// Web loader options.
WEB_LOADER_ENGINE?: string;
ENABLE_WEB_LOADER_SSL_VERIFICATION?: boolean;
PLAYWRIGHT_WS_URL?: string;
PLAYWRIGHT_TIMEOUT?: number;
FIRECRAWL_API_KEY?: string;
FIRECRAWL_API_BASE_URL?: string;
TAVILY_EXTRACT_DEPTH?: string;
// External web search / web loader endpoints.
EXTERNAL_WEB_SEARCH_URL?: string;
EXTERNAL_WEB_SEARCH_API_KEY?: string;
EXTERNAL_WEB_LOADER_URL?: string;
EXTERNAL_WEB_LOADER_API_KEY?: string;
// YouTube loader options.
YOUTUBE_LOADER_LANGUAGE?: string[];
YOUTUBE_LOADER_PROXY_URL?: string;
YOUTUBE_LOADER_TRANSLATION?: string;
};
/**
 * Payload accepted by `updateRAGConfig`, which wraps it as `{ form_data }`
 * in the request body for `/config/update`.
 *
 * All fields are optional. `knowledge_id` is what the server-side handler
 * passes to `Knowledges.get_knowledge_by_id` to pick the knowledge base whose
 * stored RAG config is updated.
 */
type RAGConfigForm = {
PDF_EXTRACT_IMAGES?: boolean;
ENABLE_GOOGLE_DRIVE_INTEGRATION?: boolean;
ENABLE_ONEDRIVE_INTEGRATION?: boolean;
chunk?: ChunkConfigForm;
content_extraction?: ContentExtractConfigForm;
// NOTE(review): this text comes from a diff view without +/- markers; the next
// two lines look like a removed/added pair (`web_loader_ssl_verification`
// replaced by `web`) — confirm against the real file before relying on both.
web_loader_ssl_verification?: boolean;
web?: WebConfigForm;
youtube?: YoutubeConfigForm;
// Identifier of the knowledge base the settings apply to.
knowledge_id?: string;
};
/**
 * Request form that identifies a collection by its name.
 *
 * NOTE(review): the newer signatures of `getRAGConfig`, `getEmbeddingConfig` and
 * `getRerankingConfig` visible in this file take `CollectionForm` instead, and no
 * TypeScript declaration of `CollectionForm` is visible here — confirm that it
 * is declared and whether this type is still needed.
 */
type CollectionNameForm = {
collection_name: string;
};
export const updateRAGConfig = async (token: string, payload: RAGConfigForm, collectionForm?: CollectionNameForm) => {
export const updateRAGConfig = async (token: string, form_data: RAGConfigForm) => {
let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config/update`, {
@ -76,9 +123,8 @@ export const updateRAGConfig = async (token: string, payload: RAGConfigForm, col
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
...payload,
...(collectionForm ? { collectionForm: collectionForm } : {})
})
form_data
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
@ -160,7 +206,7 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings
return res;
};
export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionNameForm) => {
export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, {
@ -200,7 +246,7 @@ type EmbeddingModelUpdateForm = {
embedding_engine: string;
embedding_model: string;
embedding_batch_size?: number;
collection_name?: string;
knowledge_id?: string;
};
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
@ -233,7 +279,7 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
return res;
};
export const getRerankingConfig = async (token: string, collectionForm?: CollectionNameForm) => {
export const getRerankingConfig = async (token: string, collectionForm?: CollectionForm) => {
let error = null;
const res = await fetch(`${RETRIEVAL_API_BASE_URL}/reranking`, {
@ -265,7 +311,7 @@ export const getRerankingConfig = async (token: string, collectionForm?: Collect
/**
 * Payload for updating the reranking model.
 *
 * The server-side `update_reranking_config` handler looks the knowledge base up
 * via `form_data.knowledge_id`.
 */
type RerankingModelUpdateForm = {
reranking_model: string;
// NOTE(review): likely the removed side of the diff (superseded by
// `knowledge_id` below) — confirm against the real file.
collection_name?: string;
knowledge_id?: string;
};