Refactoring: Adjusted to newly added rag_config column

This commit is contained in:
Maytown 2025-05-12 12:47:55 +02:00
parent 0b9ed1ee42
commit 0d2eefd83d

View File

@ -213,9 +213,9 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
Otherwise, return the embedding configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the embedding configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config = knowledge_base.rag_config
return {
"status": True,
"embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
@ -256,9 +256,9 @@ async def get_reranking_config(request: Request, collectionForm: CollectionNameF
Otherwise, return the reranking configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the reranking configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config = knowledge_base.rag_config
return {
"status": True,
"reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL),
@ -287,77 +287,137 @@ class EmbeddingModelUpdateForm(BaseModel):
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
collection_name: Optional[str] = None
@router.post("/embedding/update")
async def update_embedding_config(
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
# TODO Update for individual rag config
log.info(
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
"""
Update the embedding model configuration.
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
Otherwise, update the RAG configuration in the database for the user's knowledge base.
"""
try:
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.rag_config
log.info(
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
)
# Update embedding-related fields
rag_config["embedding_engine"] = form_data.embedding_engine
rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
)
request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
if form_data.ollama_config is not None:
request.app.state.config.RAG_OLLAMA_BASE_URL = (
form_data.ollama_config.url
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
# Update the embedding function
if not request.app.state.ef.get("embedding_model"):
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
rag_config["embedding_engine"],
rag_config["embedding_model"],
)
request.app.state.config.RAG_OLLAMA_API_KEY = (
form_data.ollama_config.key
request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function(
rag_config["embedding_engine"],
rag_config["embedding_model"],
request.app.state.ef[rag_config["embedding_model"]],
(
rag_config["openai_config"]["url"]
if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["url"]
),
(
rag_config["openai_config"]["key"]
if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["key"]
),
rag_config["embedding_batch_size"]
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
# Save the updated configuration to the database
Knowledges.update_knowledge_data_by_id(
id=form_data.collection_name, data={"rag_config": rag_config}
)
request.app.state.ef = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
)
return {
"status": True,
"embedding_engine": rag_config["embedding_engine"],
"embedding_model": rag_config["embedding_model"],
"embedding_batch_size": rag_config["embedding_batch_size"],
"openai_config": rag_config.get("openai_config", {}),
"ollama_config": rag_config.get("ollama_config", {}),
"message": "Embedding configuration updated in the database.",
}
else:
# Update the global configuration
log.info(
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef,
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
request.app.state.config.RAG_OPENAI_API_KEY = form_data.openai_config.key
return {
"status": True,
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.RAG_OPENAI_API_KEY,
},
"ollama_config": {
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
}
if form_data.ollama_config is not None:
request.app.state.config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url
request.app.state.config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
# Update the embedding function
if not request.app.state.ef.get(form_data.embedding_model):
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
)
request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL],
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
"key": request.app.state.config.RAG_OPENAI_API_KEY,
},
"ollama_config": {
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"message": "Embedding configuration updated globally.",
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
@ -381,14 +441,27 @@ async def update_reranking_config(
"""
try:
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
# TODO Update reranking accordingly
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config = knowledge_base.rag_config
log.info(
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
)
rag_config["reranking_model"] = form_data.reranking_model
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
)
try:
if not request.app.state.rf.get(rag_config["reranking_model"]):
request.app.state.rf[rag_config["reranking_model"]] = get_rf(
rag_config["reranking_model"],
True,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
rag_config["ENABLE_RAG_HYBRID_SEARCH"] = False
return {
"status": True,
"reranking_model": rag_config["reranking_model"],
@ -402,10 +475,13 @@ async def update_reranking_config(
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
try:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_MODEL,
True,
)
if request.app.state.rf.get(form_data.reranking_model):
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
else:
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
request.app.state.config.RAG_RERANKING_MODEL,
True,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@ -431,9 +507,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
Otherwise, return the RAG configuration stored in the database.
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Return the RAG configuration from the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config = knowledge_base.rag_config
web_config = rag_config.get("web", {})
return {
"status": True,
@ -700,9 +776,9 @@ async def update_rag_config(
"""
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
# Update the RAG configuration in the database
rag_config = knowledge_base.data.get("rag_config", {})
rag_config = knowledge_base.rag_config
# Update only the provided fields in the rag_config
for field, value in form_data.model_dump(exclude_unset=True).items():
@ -710,7 +786,9 @@ async def update_rag_config(
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
else:
rag_config[field] = value
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
request.app.state.rf[rag_config["reranking_model"]] = None
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data={"rag_config": rag_config}
)
@ -748,7 +826,7 @@ async def update_rag_config(
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER
@ -1052,15 +1130,15 @@ def save_docs_to_vector_db(
log.info(
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
)
rag_config = {}
# Retrieve the knowledge base using the knowledge_id (if one was provided)
if knowledge_id:
knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
# Retrieve the RAG configuration
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.data.get("rag_config", {})
print("RAG CONFIG: ", rag_config)
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use knowledge-base-specific or default configurations
text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER)
chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)
@ -1072,7 +1150,7 @@ def save_docs_to_vector_db(
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
result = VECTOR_DB_CLIENT.query(
@ -1156,7 +1234,7 @@ def save_docs_to_vector_db(
embedding_function = get_embedding_function(
embedding_engine,
embedding_model,
request.app.state.ef,
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
(
openai_api_base_url
if embedding_engine == "openai"
@ -1224,16 +1302,16 @@ def process_file(
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
# Retrieve the RAG configuration
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.data.get("rag_config", {})
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db
elif form_data.knowledge_id:
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
# Retrieve the RAG configuration
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.data.get("rag_config", {})
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use knowledge-base-specific or default configurations
content_extraction_engine = rag_config.get(
@ -1906,7 +1984,7 @@ def query_doc_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@ -1957,7 +2035,7 @@ def query_collection_handler(
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(