This commit is contained in:
Maytown
2025-06-22 13:34:53 +02:00
committed by GitHub
17 changed files with 4383 additions and 965 deletions

View File

@@ -2336,6 +2336,44 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
)
import json  # parse JSON-valued env overrides; kept local because the file header is outside this view

DEFAULT_RAG_SETTINGS = PersistentConfig(
    "DEFAULT_RAG_SETTINGS",
    "rag.default_settings",
    # Env var is a string; any casing of "true" enables the flag.
    os.getenv("DEFAULT_RAG_SETTINGS", "True").lower() == "true",
)


def _dict_env(name: str, default: dict) -> dict:
    """Read env var *name* as a JSON object, returning *default* when unset.

    Bug fix: ``os.getenv(name, dict_default)`` returned the raw *string*
    whenever the variable was set, so these dict-typed configs silently held
    a ``str``. Invalid JSON now fails fast at startup instead of propagating
    a wrong type into app state.
    """
    raw = os.getenv(name)
    return default if raw is None else json.loads(raw)


# Models downloaded locally, keyed by embedding engine ("" = sentence-transformers).
DOWNLOADED_EMBEDDING_MODELS = PersistentConfig(
    "DOWNLOADED_EMBEDDING_MODELS",
    "rag.downloaded_embedding_models",
    _dict_env(
        "DOWNLOADED_EMBEDDING_MODELS",
        {
            "": ["sentence-transformers/all-MiniLM-L6-v2"],
            "openai": ["text-embedding-3-small"],
            "ollama": [],
            "azure_openai": [],
        },
    ),
)

DOWNLOADED_RERANKING_MODELS = PersistentConfig(
    "DOWNLOADED_RERANKING_MODELS",
    "rag.downloaded_reranking_models",
    _dict_env("DOWNLOADED_RERANKING_MODELS", {"": [], "external": []}),
)

# Models currently loaded in memory, keyed by engine.
LOADED_EMBEDDING_MODELS = PersistentConfig(
    "LOADED_EMBEDDING_MODELS",
    "rag.loaded_embedding_models",
    _dict_env(
        "LOADED_EMBEDDING_MODELS",
        {
            "": ["sentence-transformers/all-MiniLM-L6-v2"],
            "openai": [],
            "ollama": [],
            "azure_openai": [],
        },
    ),
)

LOADED_RERANKING_MODELS = PersistentConfig(
    "LOADED_RERANKING_MODELS",
    "rag.loaded_reranking_models",
    _dict_env("LOADED_RERANKING_MODELS", {"": [], "external": []}),
)
####################################
# Web Search (RAG)

View File

@@ -0,0 +1,84 @@
"""Peewee migrations -- 019_add_rag_config_to_knowledge.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
"""Add rag_config field to knowledge table if not present."""
from contextlib import suppress
from peewee_migrate import Migrator
import peewee as pw
import json
# Try importing JSONField from playhouse.postgres_ext
with suppress(ImportError):
from playhouse.postgres_ext import JSONField as PostgresJSONField
# Fallback JSONField for SQLite (stores JSON as text)
class SQLiteJSONField(pw.TextField):
    """TextField that transparently (de)serializes JSON for SQLite.

    Postgres gets playhouse's native JSONField; SQLite has no JSON column
    type, so values are stored as their JSON text representation.
    """

    def db_value(self, value):
        # Bug fix: json.dumps(None) is the 4-char text 'null'; preserve
        # SQL NULL instead so db_value/python_value stay symmetric.
        if value is None:
            return None
        return json.dumps(value)

    def python_value(self, value):
        # NULL column -> None; otherwise decode the stored JSON text.
        if value is None:
            return None
        return json.loads(value)
def get_compatible_json_field(database: pw.Database):
    """Return a JSON-compatible field for the current database.

    Uses playhouse's JSONField on Postgres when the import above succeeded;
    otherwise (SQLite, or Postgres without playhouse) falls back to the
    text-based SQLiteJSONField. The default marks existing rows as using
    the global RAG settings.

    Bug fix: the playhouse import is wrapped in suppress(ImportError), so
    PostgresJSONField may be undefined; previously that raised NameError
    here instead of falling back.
    """
    # NOTE(review): this default dict instance is shared by all rows peewee
    # populates with it — confirm rows never mutate it in place.
    default = {"DEFAULT_RAG_SETTINGS": True}
    if not isinstance(database, pw.SqliteDatabase) and "PostgresJSONField" in globals():
        return PostgresJSONField(null=False, default=default)
    return SQLiteJSONField(null=False, default=default)
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Add rag_config JSON field to knowledge table"""
    if "knowledge" not in database.get_tables():
        print("Knowledge table hasn't been created yet, skipping migration.")
        return

    # Minimal stand-in model: peewee-migrate only needs the table name
    # to emit the ALTER TABLE statement.
    class Knowledge(pw.Model):
        class Meta:
            table_name = "knowledge"

    # Bind the stub to the live connection so DDL targets the right DB.
    Knowledge._meta.database = database

    migrator.add_fields(Knowledge, rag_config=get_compatible_json_field(database))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Remove rag_config field from knowledge table."""
    if "knowledge" not in database.get_tables():
        print("Knowledge table hasn't been created yet, skipping migration.")
        return

    # Throwaway model bound to the live DB; only the table name matters
    # for dropping the column.
    class Knowledge(pw.Model):
        class Meta:
            table_name = "knowledge"

    Knowledge._meta.database = database

    migrator.remove_fields(Knowledge, "rag_config")

View File

@@ -250,6 +250,11 @@ from open_webui.config import (
PDF_EXTRACT_IMAGES,
YOUTUBE_LOADER_LANGUAGE,
YOUTUBE_LOADER_PROXY_URL,
DEFAULT_RAG_SETTINGS,
DOWNLOADED_EMBEDDING_MODELS,
DOWNLOADED_RERANKING_MODELS,
LOADED_EMBEDDING_MODELS,
LOADED_RERANKING_MODELS,
# Retrieval (Web Search)
ENABLE_WEB_SEARCH,
WEB_SEARCH_ENGINE,
@@ -836,6 +841,11 @@ app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY
app.state.config.DEFAULT_RAG_SETTINGS = DEFAULT_RAG_SETTINGS
app.state.config.DOWNLOADED_EMBEDDING_MODELS = DOWNLOADED_EMBEDDING_MODELS
app.state.config.DOWNLOADED_RERANKING_MODELS = DOWNLOADED_RERANKING_MODELS
app.state.config.LOADED_EMBEDDING_MODELS = LOADED_EMBEDDING_MODELS
app.state.config.LOADED_RERANKING_MODELS = LOADED_RERANKING_MODELS
app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
@@ -843,62 +853,61 @@ app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
app.state.rf = None
app.state.EMBEDDING_FUNCTION = {}
app.state.ef = {}
app.state.rf = {}
app.state.YOUTUBE_LOADER_TRANSLATION = None
try:
app.state.ef = get_ef(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
# Load all embedding models that are currently in use
for engine, model_list in app.state.config.LOADED_EMBEDDING_MODELS.items():
for model in model_list:
if engine == "azure_openai":
# For Azure OpenAI, model is a dict: {model_name: version}
model_name, azure_openai_api_version = next(iter(model.items()))
model = model_name
app.state.ef[model] = get_ef(
engine,
model,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION[model] = get_embedding_function(
engine,
model,
app.state.ef[model],
(
app.state.config.RAG_OPENAI_API_BASE_URL
if engine == "openai"
else app.state.config.RAG_OLLAMA_BASE_URL
),
(
app.state.config.RAG_OPENAI_API_KEY
if engine == "openai"
else app.state.config.RAG_OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if engine == "azure_openai"
else None
),
)
# Load all reranking models that are currently in use
for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items():
for model in model_list:
app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf(
engine,
model["RAG_RERANKING_MODEL"],
model["RAG_EXTERNAL_RERANKER_URL"],
model["RAG_EXTERNAL_RERANKER_API_KEY"],
)
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL,
app.state.config.RAG_EXTERNAL_RERANKER_URL,
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
except Exception as e:
log.error(f"Error updating models: {e}")
pass
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.ef,
(
app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
app.state.config.RAG_OLLAMA_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else (
app.state.config.RAG_OLLAMA_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
########################################
#
# CODE EXECUTION

View File

@@ -35,6 +35,7 @@ class Knowledge(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
rag_config = Column(JSON, nullable=True) # Configuration for RAG (Retrieval-Augmented Generation) model.
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
@@ -68,6 +69,7 @@ class KnowledgeModel(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
rag_config: Optional[dict] = None # Configuration for RAG (Retrieval-Augmented Generation) model.
access_control: Optional[dict] = None
@@ -97,6 +99,7 @@ class KnowledgeForm(BaseModel):
description: str
data: Optional[dict] = None
access_control: Optional[dict] = None
rag_config: Optional[dict] = {}
class KnowledgeTable:
@@ -217,5 +220,49 @@ class KnowledgeTable:
except Exception:
return False
def update_rag_config_by_id(
    self, id: str, rag_config: dict
) -> Optional[KnowledgeModel]:
    """Replace the rag_config of the knowledge row *id*.

    Also bumps updated_at. Returns the refreshed KnowledgeModel, or None
    when the update fails (errors are logged, not raised).
    """
    try:
        with get_db() as db:
            # Bug fix: dropped the unused pre-fetch
            # `knowledge = self.get_knowledge_by_id(id=id)` — a dead extra
            # DB round-trip whose result was never read.
            db.query(Knowledge).filter_by(id=id).update(
                {
                    "rag_config": rag_config,
                    "updated_at": int(time.time()),
                }
            )
            db.commit()
            # Re-fetch so the caller sees the persisted state.
            return self.get_knowledge_by_id(id=id)
    except Exception as e:
        log.exception(e)
        return None
def is_model_in_use_elsewhere(
    self, model: str, model_type: str, id: Optional[str] = None
) -> bool:
    """Return True when *model* is referenced by the rag_config of any
    other knowledge row.

    *model_type* is the rag_config JSON key to inspect (e.g. a reranking
    model key); *id*, when given, excludes that row from the check — the
    row currently being updated.
    """
    try:
        from sqlalchemy import func

        with get_db() as db:
            # JSON extraction syntax differs per backend.
            if db.bind.dialect.name == "sqlite":
                query = db.query(Knowledge).filter(
                    func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model
                )
            elif db.bind.dialect.name == "postgresql":
                # ->> extracts the key's value as text for comparison.
                query = db.query(Knowledge).filter(
                    Knowledge.rag_config.op("->>")(model_type) == model,
                )
            else:
                raise NotImplementedError(
                    f"Unsupported dialect: {db.bind.dialect.name}"
                )
            if id:
                query = query.filter(Knowledge.id != id)
            return query.first() is not None
    except Exception as e:
        log.exception(f"Error checking model usage elsewhere: {e}")
        # NOTE(review): errors report "not in use", so a transient DB error
        # can let callers free a model that is still referenced elsewhere —
        # confirm this fail-open trade-off is intended.
        return False
Knowledges = KnowledgeTable()

View File

@@ -270,13 +270,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
def query_collection(
collection_names: list[str],
queries: list[str],
embedding_function,
user,
ef,
embedding_model,
k: int,
) -> dict:
results = []
error = False
def process_query_collection(collection_name, query_embedding):
def process_query_collection(collection_name, query_embedding, k):
try:
if collection_name:
result = query_doc(
@@ -291,18 +293,30 @@ def query_collection(
log.exception(f"Error when querying the collection: {e}")
return None, e
# Generate all query embeddings (in one call)
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
)
from open_webui.models.knowledge import Knowledges
with ThreadPoolExecutor() as executor:
future_results = []
for query_embedding in query_embeddings:
for collection_name in collection_names:
for collection_name in collection_names:
rag_config = {}
knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
embedding_model = rag_config.get("embedding_model", embedding_model)
k = rag_config.get("TOP_K", k)
embedding_function=lambda query, prefix: ef[embedding_model](
query, prefix=prefix, user=user
)
# Generate embeddings for each query using the collection's embedding function
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
for query_embedding in query_embeddings:
result = executor.submit(
process_query_collection, collection_name, query_embedding
process_query_collection, collection_name, query_embedding, k
)
future_results.append(result)
task_results = [future.result() for future in future_results]
@@ -322,12 +336,14 @@ def query_collection(
def query_collection_with_hybrid_search(
collection_names: list[str],
queries: list[str],
embedding_function,
user,
ef,
k: int,
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
embedding_model: str,
) -> dict:
results = []
error = False
@@ -352,13 +368,32 @@ def query_collection_with_hybrid_search(
def process_query(collection_name, query):
try:
from open_webui.models.knowledge import Knowledges
# Use Knowledges to get per-collection RAG config
knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
rag_config = knowledge_base.rag_config
# Use config from rag_config if present, else fallback to global config
embedding_model = rag_config.get("embedding_model", embedding_model)
reranking_model = rag_config.get("reranking_function", reranking_model)
k = rag_config.get("TOP_K", k)
k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker)
r = rag_config.get("RELEVANCE_THRESHOLD", r)
hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight)
embedding_function=lambda query, prefix: ef[embedding_model](
query, prefix=prefix, user=user
),
result = query_doc_with_hybrid_search(
collection_name=collection_name,
collection_result=collection_results[collection_name],
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
reranking_function=reranking_function[reranking_model],
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
@@ -446,7 +481,8 @@ def get_sources_from_files(
request,
files,
queries,
embedding_function,
user,
ef,
k,
reranking_function,
k_reranker,
@@ -454,9 +490,10 @@ def get_sources_from_files(
hybrid_bm25_weight,
hybrid_search,
full_context=False,
embedding_model=None
):
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}"
)
extracted_collections = []
@@ -564,12 +601,14 @@ def get_sources_from_files(
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
user=user,
ef=ef,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
embedding_model=embedding_model,
)
except Exception as e:
log.debug(
@@ -581,8 +620,10 @@ def get_sources_from_files(
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
user=user,
ef=ef,
k=k,
embedding_model=embedding_model
)
except Exception as e:
log.exception(e)

View File

@@ -17,6 +17,7 @@ from fastapi import (
UploadFile,
status,
Query,
Form
)
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
@@ -90,6 +91,7 @@ def upload_file(
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user),
knowledge_id: Optional[str] = Form(None)
):
log.info(f"file.content_type: {file.content_type}")
@@ -173,18 +175,18 @@ def upload_file(
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
ProcessFileForm(file_id=id, content=result.get("text", ""), knowledge_id=knowledge_id),
user=user,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=id), user=user)
process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(request, ProcessFileForm(file_id=id), user=user)
process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:

View File

@@ -241,6 +241,69 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
)
return True
@router.post("/reindex/{id}", response_model=bool)
async def reindex_specific_knowledge_files(request: Request, id: str, user=Depends(get_verified_user)):
    """Rebuild the vector collection of a single knowledge base.

    Drops the existing collection and re-processes every referenced file.
    A knowledge base with missing/invalid `data` is deleted instead of
    reindexed. Per-file failures are logged, not raised. Returns True on
    completion, False when the id does not exist.
    """
    log.info(f"reindex_specific_knowledge_files called with id={id}")
    knowledge_base = Knowledges.get_knowledge_by_id(id=id)

    # Bug fix: an unknown id previously hit AttributeError on None below.
    if knowledge_base is None:
        log.warning(f"Knowledge base {id} not found, nothing to reindex.")
        return False

    deleted_knowledge_bases = []

    # -- Robust error handling for missing or invalid data
    if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
        log.warning(
            f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
        )
        try:
            Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
            deleted_knowledge_bases.append(knowledge_base.id)
        except Exception as e:
            log.error(
                f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
            )
        # Bug fix: previously fell through and tried to reindex the
        # just-deleted/invalid knowledge base, raising AttributeError.
        log.info(
            f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
        )
        return True

    try:
        file_ids = knowledge_base.data.get("file_ids", [])
        files = Files.get_files_by_ids(file_ids)

        # Drop the stale vector collection before re-embedding.
        try:
            if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
                VECTOR_DB_CLIENT.delete_collection(
                    collection_name=knowledge_base.id
                )
        except Exception as e:
            log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")

        failed_files = []
        for file in files:
            try:
                process_file(
                    request,
                    ProcessFileForm(
                        file_id=file.id, collection_name=knowledge_base.id
                    ),
                    user=user,
                )
            except Exception as e:
                log.error(
                    f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
                )
                failed_files.append({"file_id": file.id, "error": str(e)})
                continue

        if failed_files:
            log.warning(
                f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
            )
            for failed in failed_files:
                log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
    except Exception as e:
        log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")

    log.info(
        f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
    )
    return True
############################
# GetKnowledgeById

View File

@@ -17,7 +17,7 @@ router = APIRouter()
@router.get("/ef")
async def get_embeddings(request: Request):
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
return {"result": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]("hello world")}
############################
@@ -57,7 +57,7 @@ async def add_memory(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
"vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
memory.content, user=user
),
"metadata": {"created_at": memory.created_at},
@@ -84,7 +84,7 @@ async def query_memory(
):
results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
vectors=[request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](form_data.content, user=user)],
limit=form_data.k,
)
@@ -107,7 +107,7 @@ async def reset_memory_from_vector_db(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
"vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
memory.content, user=user
),
"metadata": {
@@ -166,7 +166,7 @@ async def update_memory_by_id(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
"vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
memory.content, user=user
),
"metadata": {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,338 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestRagConfig(AbstractPostgresTest):
    """Integration tests for the per-knowledge-base RAG config endpoints.

    setup_method creates two fixtures per test: KB "1" follows the global
    defaults (DEFAULT_RAG_SETTINGS=True) and KB "2" carries an individual
    config (DEFAULT_RAG_SETTINGS=False).
    """

    BASE_PATH = "/api/v1/config"

    def setup_class(cls):
        super().setup_class()
        from open_webui.models.knowledge import Knowledges

        cls.knowledges = Knowledges

    def setup_method(self):
        super().setup_method()
        # Insert a knowledge base with default settings
        self.knowledges.insert_new_knowledge(
            id="1",
            name="Default KB",
            rag_config={
                "DEFAULT_RAG_SETTINGS": True,
                "TEMPLATE": "default-template",
                "TOP_K": 5,
            },
        )
        # Insert a knowledge base with custom RAG config
        self.knowledges.insert_new_knowledge(
            id="2",
            name="Custom KB",
            rag_config={
                "DEFAULT_RAG_SETTINGS": False,
                "TEMPLATE": "custom-template",
                "TOP_K": 10,
                "web": {
                    "ENABLE_WEB_SEARCH": True,
                    "WEB_SEARCH_ENGINE": "custom-engine"
                }
            },
        )

    def test_get_rag_config_default(self):
        # Should return default config for knowledge base with DEFAULT_RAG_SETTINGS True
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url(""),
                json={"knowledge_id": "1"}
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["RAG_TEMPLATE"] == "default-template"
        assert data["TOP_K"] == 5
        assert data["DEFAULT_RAG_SETTINGS"] is True

    def test_get_rag_config_individual(self):
        # Should return custom config for knowledge base with DEFAULT_RAG_SETTINGS False
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url(""),
                json={"knowledge_id": "2"}
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["RAG_TEMPLATE"] == "custom-template"
        assert data["TOP_K"] == 10
        assert data["DEFAULT_RAG_SETTINGS"] is False
        assert data["web"]["ENABLE_WEB_SEARCH"] is True
        assert data["web"]["WEB_SEARCH_ENGINE"] == "custom-engine"

    def test_get_rag_config_unauthorized(self):
        # Should return 401 if not authenticated
        response = self.fast_api_client.post(
            self.create_url(""),
            json={"knowledge_id": "1"}
        )
        assert response.status_code == 401

    def test_update_rag_config_default(self):
        # Should update the global config for knowledge base with DEFAULT_RAG_SETTINGS True
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/update"),
                json={
                    "knowledge_id": "1",
                    "RAG_TEMPLATE": "updated-template",
                    "TOP_K": 42,
                    "ENABLE_RAG_HYBRID_SEARCH": False,
                }
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["RAG_TEMPLATE"] == "updated-template"
        assert data["TOP_K"] == 42
        assert data["ENABLE_RAG_HYBRID_SEARCH"] is False

    def test_update_rag_config_individual(self):
        # Should update the config for knowledge base with DEFAULT_RAG_SETTINGS False
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/update"),
                json={
                    "knowledge_id": "2",
                    "TEMPLATE": "individual-updated",
                    "TOP_K": 99,
                    "web": {
                        "ENABLE_WEB_SEARCH": False,
                        "WEB_SEARCH_ENGINE": "updated-engine"
                    }
                }
            )
        assert response.status_code == 200
        data = response.json()
        assert data["TEMPLATE"] == "individual-updated"
        assert data["TOP_K"] == 99
        assert data["web"]["ENABLE_WEB_SEARCH"] is False
        assert data["web"]["WEB_SEARCH_ENGINE"] == "updated-engine"

    def test_update_reranking_model_and_states_individual(self):
        # Simulate app state for reranking models
        app = self.fast_api_client.app
        app.state.rf = {}
        app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []}
        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []}
        # Update individual config with new reranking model
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/update"),
                json={
                    "knowledge_id": "2",
                    # Bug fix: the payload previously sent "" while the
                    # assertions below expect this concrete model to have
                    # been loaded/downloaded (cf. the *_default twin test).
                    "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3",
                    "RAG_RERANKING_ENGINE": "",
                    "RAG_EXTERNAL_RERANKER_URL": "",
                    "RAG_EXTERNAL_RERANKER_API_KEY": "",
                    "ENABLE_RAG_HYBRID_SEARCH": True,
                }
            )
        assert response.status_code == 200
        # Model should be in loaded and downloaded models
        loaded = app.state.config.LOADED_RERANKING_MODELS[""]
        downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""]
        assert any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in loaded)
        assert "BBAI/bge-reranker-v2-m3" in downloaded
        assert "BBAI/bge-reranker-v2-m3" in app.state.rf

    def test_update_reranking_model_and_states_default(self):
        # Simulate app state for reranking models
        app = self.fast_api_client.app
        app.state.rf = {}
        app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []}
        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []}
        # Update default config with new reranking model
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/update"),
                json={
                    "knowledge_id": "1",
                    "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3",
                    "RAG_RERANKING_ENGINE": "",
                    "RAG_EXTERNAL_RERANKER_URL": "",
                    "RAG_EXTERNAL_RERANKER_API_KEY": "",
                    "ENABLE_RAG_HYBRID_SEARCH": True,
                }
            )
        assert response.status_code == 200
        loaded = app.state.config.LOADED_RERANKING_MODELS[""]
        downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""]
        assert any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in loaded)
        assert "BBAI/bge-reranker-v2-m3" in downloaded
        assert "BBAI/bge-reranker-v2-m3" in app.state.rf

    def test_update_rag_config_unauthorized(self):
        # Should return 401 if not authenticated
        response = self.fast_api_client.post(
            self.create_url("/update"),
            json={"knowledge_id": "1", "RAG_TEMPLATE": "should-not-update"}
        )
        assert response.status_code == 401

    def test_reranking_model_freed_only_if_not_in_use_elsewhere(self):
        """
        Test that the reranking model is only deleted from state if no other knowledge base is using it.
        """
        app = self.fast_api_client.app
        app.state.rf = {"rerank-model-shared": object()}
        app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3"}]}
        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BBAI/bge-reranker-v2-m3"]}
        # Patch is_model_in_use_elsewhere to simulate model still in use
        from unittest.mock import patch
        with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=True):
            with mock_webui_user(id="1"):
                response = self.fast_api_client.post(
                    self.create_url("/update"),
                    json={
                        "knowledge_id": "2",
                        "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3",
                        "RAG_RERANKING_ENGINE": "",
                        "ENABLE_RAG_HYBRID_SEARCH": False,
                    }
                )
        assert response.status_code == 200
        # Model should NOT be deleted from state
        assert "rerank-model-shared" in app.state.rf
        assert any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""])
        # Now simulate model NOT in use elsewhere
        app.state.rf = {"rerank-model-shared": object()}
        app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3"}]}
        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BBAI/bge-reranker-v2-m3"]}
        with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=False):
            with mock_webui_user(id="1"):
                response = self.fast_api_client.post(
                    self.create_url("/update"),
                    json={
                        "knowledge_id": "2",
                        "RAG_RERANKING_MODEL": "BBAI/bge-reranker-v2-m3",
                        "RAG_RERANKING_ENGINE": "",
                        "ENABLE_RAG_HYBRID_SEARCH": False,
                    }
                )
        assert response.status_code == 200
        # Model should be deleted from state
        assert "rerank-model-shared" not in app.state.rf
        assert not any(m["RAG_RERANKING_MODEL"] == "BBAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""])

    def test_get_embedding_config_default(self):
        # Should return default embedding config for knowledge base with DEFAULT_RAG_SETTINGS True
        # First, add embedding config to the default KB
        self.knowledges.update_rag_config_by_id(
            id="1",
            rag_config={
                "DEFAULT_RAG_SETTINGS": True,
                "embedding_engine": "",
                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                "embedding_batch_size": 1,
                "openai_config": {"url": "https://api.openai.com", "key": "default-key"},
                "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"},
            }
        )
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/embedding"),
                json={"knowledge_id": "1"}
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["embedding_engine"] == ""
        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
        assert data["embedding_batch_size"] == 1
        assert data["openai_config"]["url"] == "https://api.openai.com"
        assert data["openai_config"]["key"] == "default-key"
        assert data["ollama_config"]["url"] == "http://localhost:11434"
        assert data["ollama_config"]["key"] == "ollama-key"

    def test_get_embedding_config_individual(self):
        # Should return custom embedding config for knowledge base with DEFAULT_RAG_SETTINGS False
        self.knowledges.update_rag_config_by_id(
            id="2",
            rag_config={
                "DEFAULT_RAG_SETTINGS": False,
                "embedding_engine": "",
                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                "embedding_batch_size": 2,
                "openai_config": {"url": "https://custom.openai.com", "key": "custom-key"},
                "ollama_config": {"url": "http://custom-ollama:11434", "key": "custom-ollama-key"},
            }
        )
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/embedding"),
                json={"knowledge_id": "2"}
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["embedding_engine"] == ""
        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
        assert data["embedding_batch_size"] == 2
        assert data["openai_config"]["url"] == "https://custom.openai.com"
        assert data["openai_config"]["key"] == "custom-key"
        assert data["ollama_config"]["url"] == "http://custom-ollama:11434"
        assert data["ollama_config"]["key"] == "custom-ollama-key"

    def test_update_embedding_config_default(self):
        # Should update the global embedding config for knowledge base with DEFAULT_RAG_SETTINGS True
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/embedding/update"),
                json={
                    "knowledge_id": "1",
                    "embedding_engine": "",
                    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                    "embedding_batch_size": 4,
                    "openai_config": {"url": "https://api.openai.com/v2", "key": "updated-key"},
                    "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"},
                }
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["embedding_engine"] == ""
        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
        assert data["embedding_batch_size"] == 4
        assert data["openai_config"]["url"] == "https://api.openai.com/v2"
        assert data["openai_config"]["key"] == "updated-key"

    def test_update_embedding_config_individual(self):
        # Should update the embedding config for knowledge base with DEFAULT_RAG_SETTINGS False
        with mock_webui_user(id="1"):
            response = self.fast_api_client.post(
                self.create_url("/embedding/update"),
                json={
                    "knowledge_id": "2",
                    "embedding_engine": "",
                    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                    "embedding_batch_size": 8,
                    "openai_config": {"url": "https://custom.openai.com/v2", "key": "custom-key-2"},
                    "ollama_config": {"url": "http://custom-ollama:11434/v2", "key": "custom-ollama-key-2"},
                }
            )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] is True
        assert data["embedding_engine"] == ""
        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
        assert data["embedding_batch_size"] == 8
        assert data["openai_config"]["url"] == "https://custom.openai.com/v2"
        assert data["openai_config"]["key"] == "custom-key-2"
        assert data["ollama_config"]["url"] == "http://custom-ollama:11434/v2"
        assert data["ollama_config"]["key"] == "custom-ollama-key-2"

View File

@@ -602,7 +602,7 @@ async def chat_image_generation_handler(
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
request: Request, body: dict, user: UserModel, model_knowledge
) -> tuple[dict, dict[str, list]]:
sources = []
@@ -640,6 +640,21 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])]
try:
# check if individual rag config is used
rag_config = {}
if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
rag_config = model_knowledge[0].get("rag_config")
k=rag_config.get("TOP_K", request.app.state.config.TOP_K)
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
reranking_function=request.app.state.rf[reranking_model] if reranking_model else None
k_reranker=rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
r=rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
hybrid_bm25_weight=rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT),
hybrid_search=rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
full_context=rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
# Offload get_sources_from_files to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
@@ -649,16 +664,16 @@ async def chat_completion_files_handler(
request=request,
files=files,
queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user
),
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
user=user,
ef=request.app.state.EMBEDDING_FUNCTION,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
hybrid_search=hybrid_search,
full_context=full_context,
embedding_model=embedding_model,
),
)
except Exception as e:
@@ -917,7 +932,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.exception(e)
try:
form_data, flags = await chat_completion_files_handler(request, form_data, user)
form_data, flags = await chat_completion_files_handler(request, form_data, user, model_knowledge)
sources.extend(flags.get("sources", []))
except Exception as e:
log.exception(e)
@@ -958,20 +973,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
)
# Adjusted RAG template step to use knowledge-base-specific configuration
rag_template_config = request.app.state.config.RAG_TEMPLATE
if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
rag_template_config = model_knowledge[0].get("rag_config").get(
"RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE
)
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model.get("owned_by") == "ollama":
form_data["messages"] = prepend_to_first_user_message_content(
rag_template(
request.app.state.config.RAG_TEMPLATE, context_string, prompt
),
rag_template(rag_template_config, context_string, prompt),
form_data["messages"],
)
else:
form_data["messages"] = add_or_update_system_message(
rag_template(
request.app.state.config.RAG_TEMPLATE, context_string, prompt
),
rag_template(rag_template_config, context_string, prompt),
form_data["messages"],
)