mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
Refactoring: Adjusted to newly added rag_config column
This commit is contained in:
parent
0b9ed1ee42
commit
0d2eefd83d
@ -213,9 +213,9 @@ async def get_embedding_config(request: Request, collectionForm: CollectionNameF
|
||||
Otherwise, return the embedding configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the embedding configuration from the database
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
rag_config = knowledge_base.rag_config
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE),
|
||||
@ -256,9 +256,9 @@ async def get_reranking_config(request: Request, collectionForm: CollectionNameF
|
||||
Otherwise, return the reranking configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the reranking configuration from the database
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
rag_config = knowledge_base.rag_config
|
||||
return {
|
||||
"status": True,
|
||||
"reranking_model": rag_config.get("reranking_model", request.app.state.config.RAG_RERANKING_MODEL),
|
||||
@ -287,77 +287,137 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/embedding/update")
|
||||
async def update_embedding_config(
|
||||
request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# TODO Update for individual rag config
|
||||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
"""
|
||||
Update the embedding model configuration.
|
||||
If DEFAULT_RAG_SETTINGS is True, update the global configuration.
|
||||
Otherwise, update the RAG configuration in the database for the user's knowledge base.
|
||||
"""
|
||||
try:
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.rag_config
|
||||
log.info(
|
||||
f"Updating embedding model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
|
||||
)
|
||||
# Update embedding-related fields
|
||||
rag_config["embedding_engine"] = form_data.embedding_engine
|
||||
rag_config["embedding_model"] = form_data.embedding_model
|
||||
rag_config["embedding_batch_size"] = form_data.embedding_batch_size
|
||||
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||
form_data.openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||
form_data.openai_config.key
|
||||
)
|
||||
rag_config["openai_config"] = {
|
||||
"url": form_data.openai_config.url,
|
||||
"key": form_data.openai_config.key,
|
||||
}
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
request.app.state.config.RAG_OLLAMA_BASE_URL = (
|
||||
form_data.ollama_config.url
|
||||
rag_config["ollama_config"] = {
|
||||
"url": form_data.ollama_config.url,
|
||||
"key": form_data.ollama_config.key,
|
||||
}
|
||||
# Update the embedding function
|
||||
if not request.app.state.ef.get("embedding_model"):
|
||||
request.app.state.ef[rag_config["embedding_model"]] = get_ef(
|
||||
rag_config["embedding_engine"],
|
||||
rag_config["embedding_model"],
|
||||
)
|
||||
request.app.state.config.RAG_OLLAMA_API_KEY = (
|
||||
form_data.ollama_config.key
|
||||
|
||||
request.app.state.EMBEDDING_FUNCTION["embedding_model"] = get_embedding_function(
|
||||
rag_config["embedding_engine"],
|
||||
rag_config["embedding_model"],
|
||||
request.app.state.ef[rag_config["embedding_model"]],
|
||||
(
|
||||
rag_config["openai_config"]["url"]
|
||||
if rag_config["embedding_engine"] == "openai"
|
||||
else rag_config["ollama_config"]["url"]
|
||||
),
|
||||
(
|
||||
rag_config["openai_config"]["key"]
|
||||
if rag_config["embedding_engine"] == "openai"
|
||||
else rag_config["ollama_config"]["key"]
|
||||
),
|
||||
rag_config["embedding_batch_size"]
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||
form_data.embedding_batch_size
|
||||
|
||||
# Save the updated configuration to the database
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=form_data.collection_name, data={"rag_config": rag_config}
|
||||
)
|
||||
|
||||
request.app.state.ef = get_ef(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": rag_config["embedding_engine"],
|
||||
"embedding_model": rag_config["embedding_model"],
|
||||
"embedding_batch_size": rag_config["embedding_batch_size"],
|
||||
"openai_config": rag_config.get("openai_config", {}),
|
||||
"ollama_config": rag_config.get("ollama_config", {}),
|
||||
"message": "Embedding configuration updated in the database.",
|
||||
}
|
||||
else:
|
||||
# Update the global configuration
|
||||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef,
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||
),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = form_data.openai_config.key
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OPENAI_API_KEY,
|
||||
},
|
||||
"ollama_config": {
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
}
|
||||
if form_data.ollama_config is not None:
|
||||
request.app.state.config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url
|
||||
request.app.state.config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key
|
||||
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||
|
||||
# Update the embedding function
|
||||
if not request.app.state.ef.get(form_data.embedding_model):
|
||||
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL] = get_ef(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL] = get_embedding_function(
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef[request.app.state.config.RAG_EMBEDDING_MODEL],
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||
),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OPENAI_API_KEY,
|
||||
},
|
||||
"ollama_config": {
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
"message": "Embedding configuration updated globally.",
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
raise HTTPException(
|
||||
@ -381,14 +441,27 @@ async def update_reranking_config(
|
||||
"""
|
||||
try:
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(form_data.collection_name)
|
||||
# TODO Update reranking accordingly
|
||||
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
rag_config = knowledge_base.rag_config
|
||||
log.info(
|
||||
f"Updating reranking model: {rag_config.get('embedding_model')} to {form_data.embedding_model}"
|
||||
)
|
||||
rag_config["reranking_model"] = form_data.reranking_model
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=knowledge_base.id, data={"rag_config": rag_config}
|
||||
)
|
||||
try:
|
||||
if not request.app.state.rf.get(rag_config["reranking_model"]):
|
||||
request.app.state.rf[rag_config["reranking_model"]] = get_rf(
|
||||
rag_config["reranking_model"],
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
rag_config["ENABLE_RAG_HYBRID_SEARCH"] = False
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"reranking_model": rag_config["reranking_model"],
|
||||
@ -402,10 +475,13 @@ async def update_reranking_config(
|
||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||
|
||||
try:
|
||||
request.app.state.rf = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
True,
|
||||
)
|
||||
if request.app.state.rf.get(form_data.reranking_model):
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = request.app.state.rf[form_data.reranking_model]
|
||||
else:
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
True,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reranking model: {e}")
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
@ -431,9 +507,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
|
||||
Otherwise, return the RAG configuration stored in the database.
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Return the RAG configuration from the database
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
rag_config = knowledge_base.rag_config
|
||||
web_config = rag_config.get("web", {})
|
||||
return {
|
||||
"status": True,
|
||||
@ -700,9 +776,9 @@ async def update_rag_config(
|
||||
"""
|
||||
knowledge_base = Knowledges.get_knowledge_by_collection_name(collectionForm.collection_name)
|
||||
|
||||
if knowledge_base and not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
# Update the RAG configuration in the database
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Update only the provided fields in the rag_config
|
||||
for field, value in form_data.model_dump(exclude_unset=True).items():
|
||||
@ -710,7 +786,9 @@ async def update_rag_config(
|
||||
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
|
||||
else:
|
||||
rag_config[field] = value
|
||||
|
||||
if not rag_config.get("ENABLE_RAG_HYBRID_SEARCH", True):
|
||||
request.app.state.rf[rag_config["reranking_model"]] = None
|
||||
|
||||
Knowledges.update_knowledge_data_by_id(
|
||||
id=knowledge_base.id, data={"rag_config": rag_config}
|
||||
)
|
||||
@ -748,7 +826,7 @@ async def update_rag_config(
|
||||
)
|
||||
# Free up memory if hybrid search is disabled
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
request.app.state.rf = None
|
||||
request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL] = None
|
||||
|
||||
request.app.state.config.TOP_K_RERANKER = (
|
||||
form_data.TOP_K_RERANKER
|
||||
@ -1052,15 +1130,15 @@ def save_docs_to_vector_db(
|
||||
log.info(
|
||||
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
|
||||
)
|
||||
|
||||
|
||||
rag_config = {}
|
||||
# Retrieve the knowledge base using the collection_name
|
||||
if knowledge_id:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
print("RAG CONFIG: ", rag_config)
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Use knowledge-base-specific or default configurations
|
||||
text_splitter_type = rag_config.get("TEXT_SPLITTER", request.app.state.config.TEXT_SPLITTER)
|
||||
chunk_size = rag_config.get("CHUNK_SIZE", request.app.state.config.CHUNK_SIZE)
|
||||
@ -1072,7 +1150,7 @@ def save_docs_to_vector_db(
|
||||
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
|
||||
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
|
||||
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
|
||||
|
||||
|
||||
# Check if entries with the same hash (metadata.hash) already exist
|
||||
if metadata and "hash" in metadata:
|
||||
result = VECTOR_DB_CLIENT.query(
|
||||
@ -1156,7 +1234,7 @@ def save_docs_to_vector_db(
|
||||
embedding_function = get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
request.app.state.ef,
|
||||
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
|
||||
(
|
||||
openai_api_base_url
|
||||
if embedding_engine == "openai"
|
||||
@ -1224,16 +1302,16 @@ def process_file(
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
|
||||
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
form_data.knowledge_id = collection_name # fallback for save_docs_to_vector_db
|
||||
|
||||
elif form_data.knowledge_id:
|
||||
knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
|
||||
|
||||
# Retrieve the RAG configuration
|
||||
if not knowledge_base.data.get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.data.get("rag_config", {})
|
||||
if not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
|
||||
rag_config = knowledge_base.rag_config
|
||||
|
||||
# Use knowledge-base-specific or default configurations
|
||||
content_extraction_engine = rag_config.get(
|
||||
@ -1906,7 +1984,7 @@ def query_doc_handler(
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
@ -1957,7 +2035,7 @@ def query_collection_handler(
|
||||
query, prefix=prefix, user=user
|
||||
),
|
||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
reranking_function=request.app.state.rf[request.app.state.config.RAG_RERANKING_MODEL],
|
||||
k_reranker=form_data.k_reranker
|
||||
or request.app.state.config.TOP_K_RERANKER,
|
||||
r=(
|
||||
|
Loading…
Reference in New Issue
Block a user