Refactoring: Adjusted to newly added rag_config column

Maytown 2025-05-12 12:47:55 +02:00
parent 0b9ed1ee42
commit 0d2eefd83d
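
The diff below replaces every lookup of the RAG settings from the JSON data blob with reads of the new dedicated rag_config column on the knowledge base. As orientation, the access pattern changes roughly as in the sketch below; the actual Knowledge model and its migration are not part of this commit, so the column definitions here are only an assumption.

# Hypothetical sketch of the schema change this commit adjusts to; the real
# Knowledge model and migration are not shown in this diff.
from sqlalchemy import JSON, Column, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Knowledge(Base):
    __tablename__ = "knowledge"
    id = Column(String, primary_key=True)
    data = Column(JSON, nullable=True)        # general metadata blob (old home of the settings)
    rag_config = Column(JSON, nullable=True)  # newly added dedicated column

# Before: rag_config = knowledge_base.data.get("rag_config", {})
# After:  rag_config = knowledge_base.rag_config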


@@ -213,9 +213,9 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
     Otherwise, return the embedding configuration stored in the database.
     """
     knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
-    if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
         # Return the embedding configuration from the database
-        rag_config = knowledge_base.data.get("rag_config", {})
+        rag_config = knowledge_base.rag_config
         return {
             "status": True,
             "embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
@@ -256,9 +256,9 @@ async def get_reranking_config(request: Request, collectionForm: CollectionNameF
     Otherwise, return the reranking configuration stored in the database.
     """
     knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
-    if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
         # Return the reranking configuration from the database
-        rag_config = knowledge_base.data.get("rag_config", {})
+        rag_config = knowledge_base.rag_config
         return {
             "status": True,
             "reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL),
@@ -287,50 +287,109 @@ class EmbeddingModelUpdateForm(BaseModel):
     embedding_engine: str
     embedding_model: str
     embedding_batch_size: Optional[int] = 1
+    collection_name: Optional[str] = None
 @router.post("/embedding/update")
 async def update_embedding_config(
     request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
-    # TODO Update for individual rag config
+    """
+    Update the embedding model configuration.
+    If DEFAULT_RAG_SETTINGS is True, update the global configuration.
+    Otherwise, update the RAG configuration in the database for the user's knowledge base.
+    """
+    try:
+        knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
+        if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+            # Update the RAG configuration in the database
+            rag_config = knowledge_base.rag_config
+            log.info(
+                f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
+            )
+            # Update embedding-related fields
+            rag_config["embedding_engine"] = form_data.embedding_engine
+            rag_config["embedding_model"] = form_data.embedding_model
+            rag_config["embedding_batch_size"] = form_data.embedding_batch_size
+            if form_data.openai_config is not None:
+                rag_config["openai_config"] = {
+                    "url": form_data.openai_config.url,
+                    "key": form_data.openai_config.key,
+                }
+            if form_data.ollama_config is not None:
+                rag_config["ollama_config"] = {
+                    "url": form_data.ollama_config.url,
+                    "key": form_data.ollama_config.key,
+                }
+            # Update the embedding function
+            if not request.app.state.ef.get("embedding_model"):
+                request.app.state.ef[rag_config["embedding_model"]] = get_ef(
+                    rag_config["embedding_engine"],
+                    rag_config["embedding_model"],
+                )
+            request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function(
+                rag_config["embedding_engine"],
+                rag_config["embedding_model"],
+                request.app.state.ef[rag_config["embedding_model"]],
+                (
+                    rag_config["openai_config"]["url"]
+                    if rag_config["embedding_engine"] == "openai"
+                    else rag_config["ollama_config"]["url"]
+                ),
+                (
+                    rag_config["openai_config"]["key"]
+                    if rag_config["embedding_engine"] == "openai"
+                    else rag_config["ollama_config"]["key"]
+                ),
+                rag_config["embedding_batch_size"]
+            )
+            # Save the updated configuration to the database
+            Knowledges.update_knowledge_data_by_id(
+                id=form_data.collection_name, data={"rag_config": rag_config}
+            )
+            return {
+                "status": True,
+                "embedding_engine": rag_config["embedding_engine"],
+                "embedding_model": rag_config["embedding_model"],
+                "embedding_batch_size": rag_config["embedding_batch_size"],
+                "openai_config": rag_config.get("openai_config", {}),
+                "ollama_config": rag_config.get("ollama_config", {}),
+                "message": "Embedding configuration updated in the database.",
+            }
+        else:
+            # Update the global configuration
             log.info(
                 f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
             )
-    try:
             request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
             request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
-        if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
             if form_data.openai_config is not None:
-                request.app.state.config.RAG_OPENAI_API_BASE_URL = (
-                    form_data.openai_config.url
-                )
-                request.app.state.config.RAG_OPENAI_API_KEY = (
-                    form_data.openai_config.key
-                )
+                request.app.state.config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
+                request.app.state.config.RAG_OPENAI_API_KEY = form_data.openai_config.key
             if form_data.ollama_config is not None:
-                request.app.state.config.RAG_OLLAMA_BASE_URL = (
-                    form_data.ollama_config.url
-                )
-                request.app.state.config.RAG_OLLAMA_API_KEY = (
-                    form_data.ollama_config.key
-                )
+                request.app.state.config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url
+                request.app.state.config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key
-        request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
-            form_data.embedding_batch_size
-        )
+            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
-        request.app.state.ef = get_ef(
+            # Update the embedding function
+            if not request.app.state.ef.get(form_data.embedding_model):
+                request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
                     request.app.state.config.RAG_EMBEDDING_ENGINE,
                     request.app.state.config.RAG_EMBEDDING_MODEL,
                 )
-        request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
+            request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function(
                 request.app.state.config.RAG_EMBEDDING_ENGINE,
                 request.app.state.config.RAG_EMBEDDING_MODEL,
-            request.app.state.ef,
+                request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL],
                 (
                     request.app.state.config.RAG_OPENAI_API_BASE_URL
                     if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@@ -357,6 +416,7 @@ async def update_embedding_config(
                     "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
                     "key": request.app.state.config.RAG_OLLAMA_API_KEY,
                 },
+                "message": "Embedding configuration updated globally.",
             }
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
@@ -381,14 +441,27 @@ async def update_reranking_config(
     """
     try:
         knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
-        # TODO UPdate reranking accoridngly
-        if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+        if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
             # Update the RAG configuration in the database
-            rag_config = knowledge_base.data.get("rag_config", {})
+            rag_config = knowledge_base.rag_config
+            log.info(
+                f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
+            )
             rag_config["reranking_model"] = form_data.reranking_model
             Knowledges.update_knowledge_data_by_id(
                 id=knowledge_base.id, data={"rag_config": rag_config}
             )
+            try:
+                if not request.app.state.rf.get(rag_config["reranking_model"]):
+                    request.app.state.rf[rag_config["reranking_model"]] = get_rf(
+                        rag_config["reranking_model"],
+                        True,
+                    )
+            except Exception as e:
+                log.error(f"Error loading reranking model: {e}")
+                rag_config["ENABLE_RAG_HYBRID_SEARCH"] = False
             return {
                 "status": True,
                 "reranking_model": rag_config["reranking_model"],
@@ -402,7 +475,10 @@ async def update_reranking_config(
             request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
             try:
-            request.app.state.rf = get_rf(
+                if request.app.state.rf.get(form_data.reranking_model):
+                    request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
+                else:
+                    request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
                         request.app.state.config.RAG_RERANKING_MODEL,
                         True,
                     )
@@ -431,9 +507,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
     Otherwise, return the RAG configuration stored in the database.
     """
     knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
-    if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
         # Return the RAG configuration from the database
-        rag_config = knowledge_base.data.get("rag_config", {})
+        rag_config = knowledge_base.rag_config
         web_config = rag_config.get("web", {})
         return {
             "status": True,
@@ -700,9 +776,9 @@ async def update_rag_config(
     """
     knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
-    if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
         # Update the RAG configuration in the database
-        rag_config = knowledge_base.data.get("rag_config", {})
+        rag_config = knowledge_base.rag_config
         # Update only the provided fields in the rag_config
         for field, value in form_data.model_dump(exclude_unset=True).items():
@@ -710,6 +786,8 @@ async def update_rag_config(
                 rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
             else:
                 rag_config[field] = value
+        if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
+            request.app.state.rf[rag_config["reranking_model"]] = None
         Knowledges.update_knowledge_data_by_id(
             id=knowledge_base.id, data={"rag_config": rag_config}
@@ -748,7 +826,7 @@ async def update_rag_config(
         )
         # Free up memory if hybrid search is disabled
         if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            request.app.state.rf = None
+            request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
         request.app.state.config.TOP_K_RERANKER = (
             form_data.TOP_K_RERANKER
@@ -1058,9 +1136,9 @@ def save_docs_to_vector_db(
     if knowledge_id:
         knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
         # Retrieve the RAG configuration
-        if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
-            rag_config = knowledge_base.data.get("rag_config", {})
+        if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+            rag_config = knowledge_base.rag_config
             print("RAG CONFIG: ", rag_config)
             # Use knowledge-base-specific or default configurations
             text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER)
             chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)
@@ -1156,7 +1234,7 @@ def save_docs_to_vector_db(
         embedding_function = get_embedding_function(
             embedding_engine,
             embedding_model,
-            request.app.state.ef,
+            request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
             (
                 openai_api_base_url
                 if embedding_engine == "openai"
@@ -1224,16 +1302,16 @@ def process_file(
             knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
             # Retrieve the RAG configuration
-            if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
-                rag_config = knowledge_base.data.get("rag_config", {})
+            if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
             form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db
         elif form_data.knowledge_id:
             knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
             # Retrieve the RAG configuration
-            if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
-                rag_config = knowledge_base.data.get("rag_config", {})
+            if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
         # Use knowledge-base-specific or default configurations
         content_extraction_engine = rag_config.get(
@@ -1906,7 +1984,7 @@ def query_doc_handler(
                     query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.rf,
+                reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
                 k_reranker=form_data.k_reranker
                 or request.app.state.config.TOP_K_RERANKER,
                 r=(
@@ -1957,7 +2035,7 @@ def query_collection_handler(
                     query, prefix=prefix, user=user
                ),
                k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.rf,
+                reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
                k_reranker=form_data.k_reranker
                or request.app.state.config.TOP_K_RERANKER,
                r=(
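
Every hunk above applies the same override rule: a knowledge base's rag_config values take effect only when that knowledge base has opted out of DEFAULT_RAG_SETTINGS; otherwise the global values on request.app.state.config are used. A minimal sketch of that lookup order, using a hypothetical resolve_setting helper that is not part of the codebase:

# Hypothetical helper, not in the codebase; illustrates the per-knowledge-base
# override with fallback to the global configuration.
def resolve_setting(knowledge_base, key, global_value):
    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
        return knowledge_base.rag_config.get(key, global_value)
    return global_value

# e.g. chunk_size = resolve_setting(kb, "CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)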