mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
Fix: Added compatibility of azure openai for individual rag config - fixed query doc handler and query collection handler to handle individual rag embedding functions
This commit is contained in:
parent
644cfa139b
commit
4c19aaaa64
@ -332,16 +332,26 @@ async def update_embedding_config(
|
||||
rag_config["embedding_model"] = form_data.embedding_model
|
||||
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
||||
|
||||
|
||||
rag_config["openai_config"] = {
|
||||
"url": form_data.openai_config.url,
|
||||
"key": form_data.openai_config.key,
|
||||
}
|
||||
|
||||
rag_config["ollama_config"] = {
|
||||
"url": form_data.ollama_config.url,
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
# Update OpenAI, Ollama, and Azure OpenAI configurations if provided
|
||||
if form_data.openai_config is not None:
|
||||
rag_config["openai_config"] = {
|
||||
"url": form_data.openai_config.url,
|
||||
"key": form_data.openai_config.key,
|
||||
}
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
rag_config["ollama_config"] = {
|
||||
"url": form_data.ollama_config.url,
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
rag_config["azure_openai_config"] = {
|
||||
"url": form_data.azure_openai_config.url,
|
||||
"key": form_data.azure_openai_config.key,
|
||||
"version": form_data.azure_openai_config.version,
|
||||
}
|
||||
|
||||
# Update the embedding function
|
||||
if not rag_config["embedding_model"] in request.app.state.ef:
|
||||
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
||||
@ -363,10 +373,20 @@ async def update_embedding_config(
|
||||
if rag_config["embedding_engine"] == "openai"
|
||||
else rag_config["ollama_config"]["key"]
|
||||
),
|
||||
rag_config["embedding_batch_size"]
|
||||
rag_config["embedding_batch_size"],
|
||||
azure_api_version=(
|
||||
rag_config["azure_openai_config"]["version"]
|
||||
if rag_config["embedding_engine"] == "azure_openai"
|
||||
else None
|
||||
)
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
if rag_config["embedding_engine"] == "azure_openai":
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(
|
||||
{rag_config["embedding_model"]: rag_config.get("azure_openai_config", {}).get("version")}
|
||||
)
|
||||
else:
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]].append(rag_config["embedding_model"])
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable reranking models
|
||||
if not rag_config["embedding_model"] in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[rag_config["embedding_engine"]]:
|
||||
@ -387,9 +407,9 @@ async def update_embedding_config(
|
||||
"embedding_batch_size": rag_config["embedding_batch_size"],
|
||||
"openai_config": rag_config.get("openai_config", {}),
|
||||
"ollama_config": rag_config.get("ollama_config", {}),
|
||||
"azure_openai_config": rag_config.get("azure_openai_config", {}),
|
||||
"DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
|
||||
"LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
|
||||
"message": "Embedding configuration updated in the database.",
|
||||
}
|
||||
else:
|
||||
# Update the global configuration
|
||||
@ -417,18 +437,21 @@ async def update_embedding_config(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
|
||||
"ollama",
|
||||
"openai",
|
||||
"azure_openai",
|
||||
]:
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||
form_data.openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||
form_data.openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
|
||||
"ollama",
|
||||
"openai",
|
||||
"azure_openai",
|
||||
]:
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||
form_data.openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||
form_data.openai_config.key
|
||||
)
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
request.app.state.config.RAG_OLLAMA_BASE_URL = (
|
||||
@ -438,64 +461,20 @@ async def update_embedding_config(
|
||||
form_data.ollama_config.key
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||
form_data.embedding_batch_size
|
||||
)
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||
form_data.embedding_batch_size
|
||||
)
|
||||
|
||||
# Update the embedding function
|
||||
if not form_data.embedding_model in request.app.state.ef:
|
||||
@ -534,7 +513,12 @@ async def update_embedding_config(
|
||||
),
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai":
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(
|
||||
{request.app.state.config.RAG_EMBEDDING_MODEL: request.app.state.config.RAG_AZURE_OPENAI_API_VERSION}
|
||||
)
|
||||
else:
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()
|
||||
# add model to state for selectable embedding models
|
||||
if not request.app.state.config.RAG_EMBEDDING_MODEL in request.app.state.config.DOWNLOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE]:
|
||||
@ -1541,9 +1525,9 @@ def save_docs_to_vector_db(
|
||||
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
embedding_function = get_embedding_function(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef,
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
request.app.state.ef[embedding_model],
|
||||
(
|
||||
openai_api_base_url
|
||||
if embedding_engine == "openai"
|
||||
@ -1554,9 +1538,9 @@ def save_docs_to_vector_db(
|
||||
)
|
||||
),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
openai_api_key
|
||||
if embedding_engine == "openai"
|
||||
else ollama_api_key
|
||||
),
|
||||
embedding_batch_size,
|
||||
azure_api_version=(
|
||||
@ -2370,7 +2354,24 @@ def query_doc_handler(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
try:
|
||||
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
# Try to get individual rag config for this collection
|
||||
rag_config = {}
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Use config from rag_config if present, else fallback to global config
|
||||
enable_hybrid = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
|
||||
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
|
||||
top_k = form_data.k if form_data.k else rag_config.get("TOP_K", request.app.state.config.TOP_K)
|
||||
top_k_reranker = form_data.k_reranker if form_data.k_reranker else rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
|
||||
relevance_threshold = form_data.r if form_data.r else rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
|
||||
hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None)
|
||||
if hybrid_bm25_weight is None:
|
||||
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
|
||||
|
||||
if enable_hybrid:
|
||||
collection_results = {}
|
||||
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
||||
collection_name=form_data.collection_name
|
||||
@ -2379,32 +2380,23 @@ def query_doc_handler(
|
||||
collection_name=form_data.collection_name,
|
||||
collection_result=collection_results[form_data.collection_name],
|
||||
query=form_data.query,
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
form_data.r
|
||||
if form_data.r
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
),
|
||||
hybrid_bm25_weight=(
|
||||
form_data.hybrid_bm25_weight
|
||||
if form_data.hybrid_bm25_weight
|
||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||
),
|
||||
k=top_k,
|
||||
reranking_function=request.app.state.rf[reranking_model],
|
||||
k_reranker=top_k_reranker,
|
||||
r=relevance_threshold,
|
||||
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model](
|
||||
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
k=top_k,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -2436,11 +2428,10 @@ def query_collection_handler(
|
||||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
user=user,
|
||||
ef=request.app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
reranking_function=request.app.state.rf,
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
@ -2453,14 +2444,16 @@ def query_collection_handler(
|
||||
if form_data.hybrid_bm25_weight
|
||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||
),
|
||||
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
reranking_model=request.app.state.config.RAG_RERANKING_MODEL,
|
||||
)
|
||||
else:
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
queries=[form_data.query],
|
||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
user=user,
|
||||
ef=request.app.state.EMBEDDING_FUNCTION,
|
||||
embedding_model=request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user