Fix: Add Azure OpenAI compatibility for individual RAG configs — updated the query-doc and query-collection handlers to use the per-config RAG embedding functions

This commit is contained in:
weberm1 2025-06-06 12:05:19 +02:00
parent 644cfa139b
commit 4c19aaaa64

View File

@ -332,16 +332,26 @@ async def update_embedding_config(
rag_config["embedding_model"] = form_data.embedding_model rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size rag_config["embedding_batch_size"] = form_data.embedding_batch_size
# Update OpenAI, Ollama, and Azure OpenAI configurations if provided
if form_data.openai_config is not None:
rag_config["openai_config"] = { rag_config["openai_config"] = {
"url": form_data.openai_config.url, "url": form_data.openai_config.url,
"key": form_data.openai_config.key, "key": form_data.openai_config.key,
} }
if form_data.ollama_config is not None:
rag_config["ollama_config"] = { rag_config["ollama_config"] = {
"url": form_data.ollama_config.url, "url": form_data.ollama_config.url,
"key": form_data.ollama_config.key, "key": form_data.ollama_config.key,
} }
if form_data.azure_openai_config is not None:
rag_config["azure_openai_config"] = {
"url": form_data.azure_openai_config.url,
"key": form_data.azure_openai_config.key,
"version": form_data.azure_openai_config.version,
}
# Update the embedding function # Update the embedding function
if not rag_config["embedding_model"] in request.app.state.ef: if not rag_config["embedding_model"] in request.app.state.ef:
request.app.state.ef[rag_config["embedding_model"]] = get_ef( request.app.state.ef[rag_config["embedding_model"]] = get_ef(
@ -363,9 +373,19 @@ async def update_embedding_config(
if rag_config["embedding_engine"] == "openai" if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["key"] else rag_config["ollama_config"]["key"]
), ),
rag_config["embedding_batch_size"] rag_config["embedding_batch_size"],
azure_api_version=(
rag_config["azure_openai_config"]["version"]
if rag_config["embedding_engine"] == "azure_openai"
else None
)
) )
# add model to state for reloading on startup # add model to state for reloading on startup
if rag_config["embedding_engine"] == "azure_openai":
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(
{rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")}
)
else:
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"]) request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable reranking models # add model to state for selectable reranking models
@ -387,9 +407,9 @@ async def update_embedding_config(
"embedding_batch_size": rag_config["embedding_batch_size"], "embedding_batch_size": rag_config["embedding_batch_size"],
"openai_config": rag_config.get("openai_config", {}), "openai_config": rag_config.get("openai_config", {}),
"ollama_config": rag_config.get("ollama_config", {}), "ollama_config": rag_config.get("ollama_config", {}),
"azure_openai_config": rag_config.get("azure_openai_config", {}),
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"], "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"], "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
"message": "Embedding configuration updated in the database.",
} }
else: else:
# Update the global configuration # Update the global configuration
@ -417,6 +437,9 @@ async def update_embedding_config(
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if request.app.state.config.RAG_EMBEDDING_ENGINE in [ if request.app.state.config.RAG_EMBEDDING_ENGINE in [
"ollama", "ollama",
"openai", "openai",
@ -438,50 +461,6 @@ async def update_embedding_config(
form_data.ollama_config.key form_data.ollama_config.key
) )
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None: if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url form_data.azure_openai_config.url
@ -534,6 +513,11 @@ async def update_embedding_config(
), ),
) )
# add model to state for reloading on startup # add model to state for reloading on startup
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai":
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(
{request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION}
)
else:
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL) request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save() request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable embedding models # add model to state for selectable embedding models
@ -1541,9 +1525,9 @@ def save_docs_to_vector_db(
log.info(f"adding to collection {collection_name}") log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function( embedding_function = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE, embedding_engine,
request.app.state.config.RAG_EMBEDDING_MODEL, embedding_model,
request.app.state.ef, request.app.state.ef[embedding_model],
( (
openai_api_base_url openai_api_base_url
if embedding_engine == "openai" if embedding_engine == "openai"
@ -1554,9 +1538,9 @@ def save_docs_to_vector_db(
) )
), ),
( (
request.app.state.config.RAG_OPENAI_API_KEY openai_api_key
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if embedding_engine == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY else ollama_api_key
), ),
embedding_batch_size, embedding_batch_size,
azure_api_version=( azure_api_version=(
@ -2370,7 +2354,24 @@ def query_doc_handler(
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
try: try:
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: # Try to get individual rag config for this collection
rag_config = {}
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use config from rag_config if present, else fallback to global config
enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
top_k = form_data.k if form_data.k else rag_config.get("TOP_K", request.app.state.config.TOP_K)
top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None)
if hybrid_bm25_weight is None:
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
if enable_hybrid:
collection_results = {} collection_results = {}
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
collection_name=form_data.collection_name collection_name=form_data.collection_name
@ -2379,32 +2380,23 @@ def query_doc_handler(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name], collection_result=collection_results[form_data.collection_name],
query=form_data.query, query=form_data.query,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
query, prefix=prefix, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=top_k,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], reranking_function=request.app.state.rf[reranking_model],
k_reranker=form_data.k_reranker k_reranker=top_k_reranker,
or request.app.state.config.TOP_K_RERANKER, r=relevance_threshold,
r=( hybrid_bm25_weight=hybrid_bm25_weight,
form_data.r
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
hybrid_bm25_weight=(
form_data.hybrid_bm25_weight
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
user=user, user=user,
) )
else: else:
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model](
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=top_k,
user=user, user=user,
) )
except Exception as e: except Exception as e:
@ -2436,11 +2428,10 @@ def query_collection_handler(
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( user=user,
query, prefix=prefix, user=user ef=request.app.state.EMBEDDING_FUNCTION,
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL], reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER, or request.app.state.config.TOP_K_RERANKER,
r=( r=(
@ -2453,14 +2444,16 @@ def query_collection_handler(
if form_data.hybrid_bm25_weight if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT else request.app.state.config.HYBRID_BM25_WEIGHT
), ),
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
reranking_model=request.app.state.config.RAG_RERANKING_MODEL,
) )
else: else:
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]( user=user,
query, prefix=prefix, user=user ef=request.app.state.EMBEDDING_FUNCTION,
), embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
) )