FIX: adjusted to send knowledge base id for saving docs - fixed update_reranking_config to handle knowledge config by collection name

This commit is contained in:
Maytown 2025-05-06 12:50:08 +02:00
parent 9233f1f848
commit 49e4375263

View File

@ -371,7 +371,7 @@ class RerankingModelUpdateForm(BaseModel):
@router.post("/reranking/update") @router.post("/reranking/update")
async def update_reranking_config( async def update_reranking_config(
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) request: Request, form_data: RerankingModelUpdateForm, collectionForm: CollectionNameForm, user=Depends(get_admin_user)
): ):
""" """
Update the reranking model configuration. Update the reranking model configuration.
@ -379,7 +379,7 @@ async def update_reranking_config(
Otherwise, update the RAG configuration in the database for the user's knowledge base. Otherwise, update the RAG configuration in the database for the user's knowledge base.
""" """
try: try:
knowledge_base = Knowledges.get_knowledge_base_by_user_id(user.id) knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True): if knowledge_base and not knowledge_base.data.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database # Update the RAG configuration in the database
@ -1027,6 +1027,7 @@ def save_docs_to_vector_db(
split: bool = True, split: bool = True,
add: bool = False, add: bool = False,
user=None, user=None,
knowledge_id: Optional[str] = None,
) -> bool: ) -> bool:
def _get_docs_info(docs: list[Document]) -> str: def _get_docs_info(docs: list[Document]) -> str:
docs_info = set() docs_info = set()
@ -1049,9 +1050,9 @@ def save_docs_to_vector_db(
) )
# Retrieve the knowledge base using the collection_name # Retrieve the knowledge base using the collection_name
knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name) knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
if not knowledge_base: if not knowledge_base:
raise ValueError(f"Knowledge base not found for collection: {collection_name}") raise ValueError(f"Knowledge base not found for collection: {knowledge_base}")
# Retrieve the RAG configuration # Retrieve the RAG configuration
rag_config = {} rag_config = {}
@ -1204,6 +1205,7 @@ class ProcessFileForm(BaseModel):
def process_file( def process_file(
request: Request, request: Request,
form_data: ProcessFileForm, form_data: ProcessFileForm,
knowledge_id: Optional[str] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
try: try:
@ -1215,9 +1217,9 @@ def process_file(
collection_name = f"file-{file.id}" collection_name = f"file-{file.id}"
# Retrieve the knowledge base using the collection name # Retrieve the knowledge base using the collection name
knowledge_base = Knowledges.get_knowledge_base_by_collection_name(collection_name) knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
if not knowledge_base: if not knowledge_base:
raise ValueError(f"Knowledge base not found for collection: {collection_name}") raise ValueError(f"Knowledge base not found for collection: {knowledge_base}")
# Retrieve the RAG configuration # Retrieve the RAG configuration
rag_config = {} rag_config = {}
@ -1372,6 +1374,7 @@ def process_file(
}, },
add=(True if form_data.collection_name else False), add=(True if form_data.collection_name else False),
user=user, user=user,
knowledge_id=knowledge_id,
) )
if result: if result:
@ -1423,6 +1426,7 @@ def process_text(
request: Request, request: Request,
form_data: ProcessTextForm, form_data: ProcessTextForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
knowledge_id: Optional[str] = None,
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name is None: if collection_name is None:
@ -1437,7 +1441,7 @@ def process_text(
text_content = form_data.content text_content = form_data.content
log.debug(f"text_content: {text_content}") log.debug(f"text_content: {text_content}")
result = save_docs_to_vector_db(request, docs, collection_name, user=user) result = save_docs_to_vector_db(request, docs, collection_name, user=user, knowledge_id=knowledge_id)
if result: if result:
return { return {
"status": True, "status": True,