mirror of
https://github.com/open-webui/open-webui
synced 2025-06-25 17:57:20 +00:00
Fix: Add Azure OpenAI compatibility for individual RAG configs; fix the query doc handler and query collection handler to handle individual RAG embedding functions
This commit is contained in:
parent
644cfa139b
commit
4c19aaaa64
@ -332,16 +332,26 @@ async def update_embedding_config(
|
|||||||
rag_config["embedding_model"] = form_data.embedding_model
|
rag_config["embedding_model"] = form_data.embedding_model
|
||||||
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
||||||
|
|
||||||
|
# Update OpenAI, Ollama, and Azure OpenAI configurations if provided
|
||||||
|
if form_data.openai_config is not None:
|
||||||
rag_config["openai_config"] = {
|
rag_config["openai_config"] = {
|
||||||
"url": form_data.openai_config.url,
|
"url": form_data.openai_config.url,
|
||||||
"key": form_data.openai_config.key,
|
"key": form_data.openai_config.key,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if form_data.ollama_config is not None:
|
||||||
rag_config["ollama_config"] = {
|
rag_config["ollama_config"] = {
|
||||||
"url": form_data.ollama_config.url,
|
"url": form_data.ollama_config.url,
|
||||||
"key": form_data.ollama_config.key,
|
"key": form_data.ollama_config.key,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if form_data.azure_openai_config is not None:
|
||||||
|
rag_config["azure_openai_config"] = {
|
||||||
|
"url": form_data.azure_openai_config.url,
|
||||||
|
"key": form_data.azure_openai_config.key,
|
||||||
|
"version": form_data.azure_openai_config.version,
|
||||||
|
}
|
||||||
|
|
||||||
# Update the embedding function
|
# Update the embedding function
|
||||||
if not rag_config["embedding_model"] in request.app.state.ef:
|
if not rag_config["embedding_model"] in request.app.state.ef:
|
||||||
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
||||||
@ -363,9 +373,19 @@ async def update_embedding_config(
|
|||||||
if rag_config["embedding_engine"] == "openai"
|
if rag_config["embedding_engine"] == "openai"
|
||||||
else rag_config["ollama_config"]["key"]
|
else rag_config["ollama_config"]["key"]
|
||||||
),
|
),
|
||||||
rag_config["embedding_batch_size"]
|
rag_config["embedding_batch_size"],
|
||||||
|
azure_api_version=(
|
||||||
|
rag_config["azure_openai_config"]["version"]
|
||||||
|
if rag_config["embedding_engine"] == "azure_openai"
|
||||||
|
else None
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# add model to state for reloading on startup
|
# add model to state for reloading on startup
|
||||||
|
if rag_config["embedding_engine"] == "azure_openai":
|
||||||
|
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(
|
||||||
|
{rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")}
|
||||||
|
)
|
||||||
|
else:
|
||||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||||
# add model to state for selectable reranking models
|
# add model to state for selectable reranking models
|
||||||
@ -387,9 +407,9 @@ async def update_embedding_config(
|
|||||||
"embedding_batch_size": rag_config["embedding_batch_size"],
|
"embedding_batch_size": rag_config["embedding_batch_size"],
|
||||||
"openai_config": rag_config.get("openai_config", {}),
|
"openai_config": rag_config.get("openai_config", {}),
|
||||||
"ollama_config": rag_config.get("ollama_config", {}),
|
"ollama_config": rag_config.get("ollama_config", {}),
|
||||||
|
"azure_openai_config": rag_config.get("azure_openai_config", {}),
|
||||||
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
|
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
|
||||||
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
|
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
|
||||||
"message": "Embedding configuration updated in the database.",
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Update the global configuration
|
# Update the global configuration
|
||||||
@ -417,6 +437,9 @@ async def update_embedding_config(
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||||
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||||
|
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
|
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
|
||||||
"ollama",
|
"ollama",
|
||||||
"openai",
|
"openai",
|
||||||
@ -438,50 +461,6 @@ async def update_embedding_config(
|
|||||||
form_data.ollama_config.key
|
form_data.ollama_config.key
|
||||||
)
|
)
|
||||||
|
|
||||||
if form_data.azure_openai_config is not None:
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
|
||||||
form_data.azure_openai_config.url
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
|
||||||
form_data.azure_openai_config.key
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
|
||||||
form_data.azure_openai_config.version
|
|
||||||
)
|
|
||||||
|
|
||||||
if form_data.azure_openai_config is not None:
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
|
||||||
form_data.azure_openai_config.url
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
|
||||||
form_data.azure_openai_config.key
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
|
||||||
form_data.azure_openai_config.version
|
|
||||||
)
|
|
||||||
|
|
||||||
if form_data.azure_openai_config is not None:
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
|
||||||
form_data.azure_openai_config.url
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
|
||||||
form_data.azure_openai_config.key
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
|
||||||
form_data.azure_openai_config.version
|
|
||||||
)
|
|
||||||
|
|
||||||
if form_data.azure_openai_config is not None:
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
|
||||||
form_data.azure_openai_config.url
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
|
||||||
form_data.azure_openai_config.key
|
|
||||||
)
|
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
|
||||||
form_data.azure_openai_config.version
|
|
||||||
)
|
|
||||||
|
|
||||||
if form_data.azure_openai_config is not None:
|
if form_data.azure_openai_config is not None:
|
||||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||||
form_data.azure_openai_config.url
|
form_data.azure_openai_config.url
|
||||||
@ -534,6 +513,11 @@ async def update_embedding_config(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
# add model to state for reloading on startup
|
# add model to state for reloading on startup
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai":
|
||||||
|
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(
|
||||||
|
{request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION}
|
||||||
|
)
|
||||||
|
else:
|
||||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||||
# add model to state for selectable embedding models
|
# add model to state for selectable embedding models
|
||||||
@ -1541,9 +1525,9 @@ def save_docs_to_vector_db(
|
|||||||
|
|
||||||
log.info(f"adding to collection {collection_name}")
|
log.info(f"adding to collection {collection_name}")
|
||||||
embedding_function = get_embedding_function(
|
embedding_function = get_embedding_function(
|
||||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
embedding_engine,
|
||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
embedding_model,
|
||||||
request.app.state.ef,
|
request.app.state.ef[embedding_model],
|
||||||
(
|
(
|
||||||
openai_api_base_url
|
openai_api_base_url
|
||||||
if embedding_engine == "openai"
|
if embedding_engine == "openai"
|
||||||
@ -1554,9 +1538,9 @@ def save_docs_to_vector_db(
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
request.app.state.config.RAG_OPENAI_API_KEY
|
openai_api_key
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if embedding_engine == "openai"
|
||||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
else ollama_api_key
|
||||||
),
|
),
|
||||||
embedding_batch_size,
|
embedding_batch_size,
|
||||||
azure_api_version=(
|
azure_api_version=(
|
||||||
@ -2370,7 +2354,24 @@ def query_doc_handler(
|
|||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
# Try to get individual rag config for this collection
|
||||||
|
rag_config = {}
|
||||||
|
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
|
||||||
|
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||||
|
rag_config = knowledge_base.rag_config
|
||||||
|
|
||||||
|
# Use config from rag_config if present, else fallback to global config
|
||||||
|
enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
|
||||||
|
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||||
|
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
|
||||||
|
top_k = form_data.k if form_data.k else rag_config.get("TOP_K", request.app.state.config.TOP_K)
|
||||||
|
top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
|
||||||
|
relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
|
||||||
|
hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None)
|
||||||
|
if hybrid_bm25_weight is None:
|
||||||
|
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
|
||||||
|
|
||||||
|
if enable_hybrid:
|
||||||
collection_results = {}
|
collection_results = {}
|
||||||
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
||||||
collection_name=form_data.collection_name
|
collection_name=form_data.collection_name
|
||||||
@ -2379,32 +2380,23 @@ def query_doc_handler(
|
|||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
collection_result=collection_results[form_data.collection_name],
|
collection_result=collection_results[form_data.collection_name],
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=top_k,
|
||||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
reranking_function=request.app.state.rf[reranking_model],
|
||||||
k_reranker=form_data.k_reranker
|
k_reranker=top_k_reranker,
|
||||||
or request.app.state.config.TOP_K_RERANKER,
|
r=relevance_threshold,
|
||||||
r=(
|
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||||
form_data.r
|
|
||||||
if form_data.r
|
|
||||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
|
||||||
),
|
|
||||||
hybrid_bm25_weight=(
|
|
||||||
form_data.hybrid_bm25_weight
|
|
||||||
if form_data.hybrid_bm25_weight
|
|
||||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
|
||||||
),
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return query_doc(
|
return query_doc(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model](
|
||||||
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=top_k,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -2436,11 +2428,10 @@ def query_collection_handler(
|
|||||||
return query_collection_with_hybrid_search(
|
return query_collection_with_hybrid_search(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
user=user,
|
||||||
query, prefix=prefix, user=user
|
ef=request.app.state.EMBEDDING_FUNCTION,
|
||||||
),
|
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
reranking_function=request.app.state.rf,
|
||||||
k_reranker=form_data.k_reranker
|
k_reranker=form_data.k_reranker
|
||||||
or request.app.state.config.TOP_K_RERANKER,
|
or request.app.state.config.TOP_K_RERANKER,
|
||||||
r=(
|
r=(
|
||||||
@ -2453,14 +2444,16 @@ def query_collection_handler(
|
|||||||
if form_data.hybrid_bm25_weight
|
if form_data.hybrid_bm25_weight
|
||||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||||
),
|
),
|
||||||
|
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
|
reranking_model=request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return query_collection(
|
return query_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
user=user,
|
||||||
query, prefix=prefix, user=user
|
ef=request.app.state.EMBEDDING_FUNCTION,
|
||||||
),
|
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user