Fix: Added compatibility of azure openai for individual rag config - fixed query doc handler and query collection handler to handle individual rag embedding functions

This commit is contained in:
weberm1 2025-06-06 12:05:19 +02:00
parent 644cfa139b
commit 4c19aaaa64

View File

@ -332,16 +332,26 @@ async def update_embedding_config(
rag_config["embedding_model"] = form_data.embedding_model
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
# Update OpenAI, Ollama, and Azure OpenAI configurations if provided
if form_data.openai_config is not None:
rag_config["openai_config"] = {
"url": form_data.openai_config.url,
"key": form_data.openai_config.key,
}
if form_data.ollama_config is not None:
rag_config["ollama_config"] = {
"url": form_data.ollama_config.url,
"key": form_data.ollama_config.key,
}
if form_data.azure_openai_config is not None:
rag_config["azure_openai_config"] = {
"url": form_data.azure_openai_config.url,
"key": form_data.azure_openai_config.key,
"version": form_data.azure_openai_config.version,
}
# Update the embedding function
if not rag_config["embedding_model"] in request.app.state.ef:
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
@ -363,10 +373,20 @@ async def update_embedding_config(
if rag_config["embedding_engine"] == "openai"
else rag_config["ollama_config"]["key"]
),
rag_config["embedding_batch_size"]
rag_config["embedding_batch_size"],
azure_api_version=(
rag_config["azure_openai_config"]["version"]
if rag_config["embedding_engine"] == "azure_openai"
else None
)
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
if rag_config["embedding_engine"] == "azure_openai":
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(
{rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")}
)
else:
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable reranking models
if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
@ -387,9 +407,9 @@ async def update_embedding_config(
"embedding_batch_size": rag_config["embedding_batch_size"],
"openai_config": rag_config.get("openai_config", {}),
"ollama_config": rag_config.get("ollama_config", {}),
"azure_openai_config": rag_config.get("azure_openai_config", {}),
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
"message": "Embedding configuration updated in the database.",
}
else:
# Update the global configuration
@ -417,18 +437,21 @@ async def update_embedding_config(
gc.collect()
torch.cuda.empty_cache()
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
"ollama",
"openai",
"azure_openai",
]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
)
request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
"ollama",
"openai",
"azure_openai",
]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
)
request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
if form_data.ollama_config is not None:
request.app.state.config.RAG_OLLAMA_BASE_URL = (
@ -438,64 +461,20 @@ async def update_embedding_config(
form_data.ollama_config.key
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
)
# Update the embedding function
if not form_data.embedding_model in request.app.state.ef:
@ -534,7 +513,12 @@ async def update_embedding_config(
),
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai":
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(
{request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION}
)
else:
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
# add model to state for selectable embedding models
if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
@ -1541,9 +1525,9 @@ def save_docs_to_vector_db(
log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef,
embedding_engine,
embedding_model,
request.app.state.ef[embedding_model],
(
openai_api_base_url
if embedding_engine == "openai"
@ -1554,9 +1538,9 @@ def save_docs_to_vector_db(
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
openai_api_key
if embedding_engine == "openai"
else ollama_api_key
),
embedding_batch_size,
azure_api_version=(
@ -2370,7 +2354,24 @@ def query_doc_handler(
user=Depends(get_verified_user),
):
try:
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
# Try to get individual rag config for this collection
rag_config = {}
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use config from rag_config if present, else fallback to global config
enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
top_k = form_data.k if form_data.k else rag_config.get("TOP_K", request.app.state.config.TOP_K)
top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None)
if hybrid_bm25_weight is None:
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
if enable_hybrid:
collection_results = {}
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
collection_name=form_data.collection_name
@ -2379,32 +2380,23 @@ def query_doc_handler(
collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name],
query=form_data.query,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
query, prefix=prefix, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
form_data.r
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
hybrid_bm25_weight=(
form_data.hybrid_bm25_weight
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
k=top_k,
reranking_function=request.app.state.rf[reranking_model],
k_reranker=top_k_reranker,
r=relevance_threshold,
hybrid_bm25_weight=hybrid_bm25_weight,
user=user,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model](
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
k=top_k,
user=user,
)
except Exception as e:
@ -2436,11 +2428,10 @@ def query_collection_handler(
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query, prefix=prefix, user=user
),
user=user,
ef=request.app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER,
r=(
@ -2453,14 +2444,16 @@ def query_collection_handler(
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
reranking_model=request.app.state.config.RAG_RERANKING_MODEL,
)
else:
return query_collection(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
query, prefix=prefix, user=user
),
user=user,
ef=request.app.state.EMBEDDING_FUNCTION,
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
)