Merge cca2174e79 into aef0ad2d10
commit 83c2566680
@@ -2336,6 +2336,44 @@ YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
     os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
 )
 
+DEFAULT_RAG_SETTINGS = PersistentConfig(
+    "DEFAULT_RAG_SETTINGS",
+    "rag.default_settings",
+    os.getenv("DEFAULT_RAG_SETTINGS", "True").lower() == "true",
+)
+
+DOWNLOADED_EMBEDDING_MODELS = PersistentConfig(
+    "DOWNLOADED_EMBEDDING_MODELS",
+    "rag.downloaded_embedding_models",
+    os.getenv(
+        "DOWNLOADED_EMBEDDING_MODELS",
+        {
+            "": ["sentence-transformers/all-MiniLM-L6-v2"],
+            "openai": ["text-embedding-3-small"],
+            "ollama": [],
+            "azure_openai": [],
+        },
+    ),
+)
+
+DOWNLOADED_RERANKING_MODELS = PersistentConfig(
+    "DOWNLOADED_RERANKING_MODELS",
+    "rag.downloaded_reranking_models",
+    os.getenv("DOWNLOADED_RERANKING_MODELS", {"": [], "external": []}),
+)
+
+LOADED_EMBEDDING_MODELS = PersistentConfig(
+    "LOADED_EMBEDDING_MODELS",
+    "rag.loaded_embedding_models",
+    os.getenv(
+        "LOADED_EMBEDDING_MODELS",
+        {
+            "": ["sentence-transformers/all-MiniLM-L6-v2"],
+            "openai": [],
+            "ollama": [],
+            "azure_openai": [],
+        },
+    ),
+)
+
+LOADED_RERANKING_MODELS = PersistentConfig(
+    "LOADED_RERANKING_MODELS",
+    "rag.loaded_reranking_models",
+    os.getenv("LOADED_RERANKING_MODELS", {"": [], "external": []}),
+)
+
+
 ####################################
 # Web Search (RAG)
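A caveat on the config blocks above: `os.getenv` returns a string whenever the variable is actually set, so the dict defaults only take effect when the variable is absent. A minimal sketch of a JSON-parsing wrapper that keeps the type consistent either way (`getenv_json` is a hypothetical helper, not part of this commit):

```python
import json
import os

def getenv_json(name: str, default: dict) -> dict:
    """Return `default` when unset; otherwise parse the env value as JSON."""
    raw = os.getenv(name)
    if raw is None:
        return default
    # e.g. DOWNLOADED_RERANKING_MODELS='{"": [], "external": []}'
    return json.loads(raw)
```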
@@ -0,0 +1,84 @@
+"""Peewee migrations -- 019_add_rag_config_to_knowledge.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']          # Return model in current state by name
+    > Model = migrator.ModelClass                 # Return model in current state by name
+
+    > migrator.sql(sql)                           # Run custom SQL
+    > migrator.run(func, *args, **kwargs)         # Run python function with the given args
+    > migrator.create_model(Model)                # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)  # Remove a model
+    > migrator.add_fields(model, **fields)        # Add fields to a model
+    > migrator.change_fields(model, **fields)     # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+"""Add rag_config field to knowledge table if not present."""
+
+from contextlib import suppress
+
+from peewee_migrate import Migrator
+import peewee as pw
+import json
+
+# Try importing JSONField from playhouse.postgres_ext
+with suppress(ImportError):
+    from playhouse.postgres_ext import JSONField as PostgresJSONField
+
+
+# Fallback JSONField for SQLite (stores JSON as text)
+class SQLiteJSONField(pw.TextField):
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
+        return None
+
+
+def get_compatible_json_field(database: pw.Database):
+    """Return a JSON-compatible field for the current database."""
+    if isinstance(database, pw.SqliteDatabase):
+        return SQLiteJSONField(null=False, default={"DEFAULT_RAG_SETTINGS": True})
+    else:
+        return PostgresJSONField(null=False, default={"DEFAULT_RAG_SETTINGS": True})
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Add rag_config JSON field to the knowledge table."""
+    if 'knowledge' not in database.get_tables():
+        print("Knowledge table hasn't been created yet, skipping migration.")
+        return
+
+    class Knowledge(pw.Model):
+        class Meta:
+            table_name = 'knowledge'
+
+    Knowledge._meta.database = database  # bind DB
+
+    migrator.add_fields(
+        Knowledge,
+        rag_config=get_compatible_json_field(database)
+    )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Remove the rag_config field from the knowledge table."""
+    if 'knowledge' not in database.get_tables():
+        print("Knowledge table hasn't been created yet, skipping migration.")
+        return
+
+    class Knowledge(pw.Model):
+        class Meta:
+            table_name = 'knowledge'
+
+    Knowledge._meta.database = database
+    migrator.remove_fields(Knowledge, 'rag_config')
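One edge the helper above leaves open: if the `playhouse.postgres_ext` import was suppressed, `PostgresJSONField` is undefined and the non-SQLite branch raises `NameError`. A minimal sketch of a safer variant (hypothetical, not part of this commit):

```python
def get_compatible_json_field_safe(database: pw.Database):
    """Like get_compatible_json_field, but degrades to text storage."""
    if isinstance(database, pw.SqliteDatabase):
        return SQLiteJSONField(null=False, default={"DEFAULT_RAG_SETTINGS": True})
    try:
        return PostgresJSONField(null=False, default={"DEFAULT_RAG_SETTINGS": True})
    except NameError:
        # playhouse.postgres_ext was not importable; fall back to JSON-as-text
        return SQLiteJSONField(null=False, default={"DEFAULT_RAG_SETTINGS": True})
```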
@@ -250,6 +250,11 @@ from open_webui.config import (
     PDF_EXTRACT_IMAGES,
     YOUTUBE_LOADER_LANGUAGE,
     YOUTUBE_LOADER_PROXY_URL,
+    DEFAULT_RAG_SETTINGS,
+    DOWNLOADED_EMBEDDING_MODELS,
+    DOWNLOADED_RERANKING_MODELS,
+    LOADED_EMBEDDING_MODELS,
+    LOADED_RERANKING_MODELS,
     # Retrieval (Web Search)
     ENABLE_WEB_SEARCH,
     WEB_SEARCH_ENGINE,
@@ -836,6 +841,11 @@ app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
 app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
 app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY
 
+app.state.config.DEFAULT_RAG_SETTINGS = DEFAULT_RAG_SETTINGS
+app.state.config.DOWNLOADED_EMBEDDING_MODELS = DOWNLOADED_EMBEDDING_MODELS
+app.state.config.DOWNLOADED_RERANKING_MODELS = DOWNLOADED_RERANKING_MODELS
+app.state.config.LOADED_EMBEDDING_MODELS = LOADED_EMBEDDING_MODELS
+app.state.config.LOADED_RERANKING_MODELS = LOADED_RERANKING_MODELS
+
 app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
 app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT
@@ -843,62 +853,61 @@ app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
 app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
 app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
 
-app.state.EMBEDDING_FUNCTION = None
-app.state.ef = None
-app.state.rf = None
+app.state.EMBEDDING_FUNCTION = {}
+app.state.ef = {}
+app.state.rf = {}
 
 app.state.YOUTUBE_LOADER_TRANSLATION = None
 
 
 try:
-    app.state.ef = get_ef(
-        app.state.config.RAG_EMBEDDING_ENGINE,
-        app.state.config.RAG_EMBEDDING_MODEL,
-        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
-    )
+    # Load all embedding models that are currently in use
+    for engine, model_list in app.state.config.LOADED_EMBEDDING_MODELS.items():
+        for model in model_list:
+            if engine == "azure_openai":
+                # For Azure OpenAI, model is a dict: {model_name: version}
+                model_name, azure_openai_api_version = next(iter(model.items()))
+                model = model_name
+
+            app.state.ef[model] = get_ef(
+                engine,
+                model,
+                RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+            )
+            app.state.EMBEDDING_FUNCTION[model] = get_embedding_function(
+                engine,
+                model,
+                app.state.ef[model],
+                (
+                    app.state.config.RAG_OPENAI_API_BASE_URL
+                    if engine == "openai"
+                    else app.state.config.RAG_OLLAMA_BASE_URL
+                ),
+                (
+                    app.state.config.RAG_OPENAI_API_KEY
+                    if engine == "openai"
+                    else app.state.config.RAG_OLLAMA_API_KEY
+                ),
+                app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+                azure_api_version=(
+                    app.state.config.RAG_AZURE_OPENAI_API_VERSION
+                    if engine == "azure_openai"
+                    else None
+                ),
+            )
+    # Load all reranking models that are currently in use
+    for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items():
+        for model in model_list:
+            app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf(
+                engine,
+                model["RAG_RERANKING_MODEL"],
+                model["RAG_EXTERNAL_RERANKER_URL"],
+                model["RAG_EXTERNAL_RERANKER_API_KEY"],
+            )
 
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass
 
 
-app.state.EMBEDDING_FUNCTION = get_embedding_function(
-    app.state.config.RAG_EMBEDDING_ENGINE,
-    app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.ef,
-    (
-        app.state.config.RAG_OPENAI_API_BASE_URL
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else (
-            app.state.config.RAG_OLLAMA_BASE_URL
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
-            else app.state.config.RAG_AZURE_OPENAI_BASE_URL
-        )
-    ),
-    (
-        app.state.config.RAG_OPENAI_API_KEY
-        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else (
-            app.state.config.RAG_OLLAMA_API_KEY
-            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
-            else app.state.config.RAG_AZURE_OPENAI_API_KEY
-        )
-    ),
-    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-    azure_api_version=(
-        app.state.config.RAG_AZURE_OPENAI_API_VERSION
-        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
-        else None
-    ),
-)
-
 ########################################
 #
 # CODE EXECUTION
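For readers following the startup loop above, LOADED_EMBEDDING_MODELS maps engine names to model lists; only the azure_openai entries are dicts. A sketch of the expected shape (values are illustrative, not from this commit):

```python
# Illustrative shape only; the azure_openai entry is {model_name: api_version},
# everything else is a plain model-name string.
LOADED_EMBEDDING_MODELS = {
    "": ["sentence-transformers/all-MiniLM-L6-v2"],  # local sentence-transformers
    "openai": ["text-embedding-3-small"],
    "ollama": [],
    "azure_openai": [{"text-embedding-3-large": "2024-02-01"}],  # made-up version
}
```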
@@ -35,6 +35,7 @@ class Knowledge(Base):
 
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
+    rag_config = Column(JSON, nullable=True)  # Configuration for RAG (Retrieval-Augmented Generation) model.
 
     access_control = Column(JSON, nullable=True)  # Controls data access levels.
     # Defines access control rules for this entry.
@@ -68,6 +69,7 @@ class KnowledgeModel(BaseModel):
 
     data: Optional[dict] = None
     meta: Optional[dict] = None
+    rag_config: Optional[dict] = None  # Configuration for RAG (Retrieval-Augmented Generation) model.
 
     access_control: Optional[dict] = None
 
@@ -97,6 +99,7 @@ class KnowledgeForm(BaseModel):
     description: str
     data: Optional[dict] = None
     access_control: Optional[dict] = None
+    rag_config: Optional[dict] = {}
 
 
 class KnowledgeTable:
@@ -217,5 +220,49 @@ class KnowledgeTable:
         except Exception:
             return False
 
+    def update_rag_config_by_id(
+        self, id: str, rag_config: dict
+    ) -> Optional[KnowledgeModel]:
+        try:
+            with get_db() as db:
+                knowledge = self.get_knowledge_by_id(id=id)
+                db.query(Knowledge).filter_by(id=id).update(
+                    {
+                        "rag_config": rag_config,
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_knowledge_by_id(id=id)
+        except Exception as e:
+            log.exception(e)
+            return None
+
+    def is_model_in_use_elsewhere(
+        self, model: str, model_type: str, id: Optional[str] = None
+    ) -> bool:
+        try:
+            from sqlalchemy import func
+
+            with get_db() as db:
+                if db.bind.dialect.name == "sqlite":
+                    query = db.query(Knowledge).filter(
+                        func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model
+                    )
+                elif db.bind.dialect.name == "postgresql":
+                    query = db.query(Knowledge).filter(
+                        Knowledge.rag_config.op("->>")(model_type) == model,
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported dialect: {db.bind.dialect.name}"
+                    )
+                if id:
+                    query = query.filter(Knowledge.id != id)
+
+                return query.first() is not None
+
+        except Exception as e:
+            log.exception(f"Error checking model usage elsewhere: {e}")
+            return False
+
 
 Knowledges = KnowledgeTable()
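A quick usage sketch for the two new helpers (the knowledge id and the rag_config key are hypothetical; `model_type` is whatever top-level key the rag_config stores the model under):

```python
# Hypothetical usage; "kb-123" is a made-up id.
updated = Knowledges.update_rag_config_by_id(
    "kb-123", {"DEFAULT_RAG_SETTINGS": False, "TOP_K": 10}
)
if updated and not Knowledges.is_model_in_use_elsewhere(
    "BAAI/bge-reranker-v2-m3", "RAG_RERANKING_MODEL", id="kb-123"
):
    print("no other knowledge base references this model; safe to unload")
```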
@@ -270,13 +270,15 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
 def query_collection(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
+    embedding_model,
     k: int,
 ) -> dict:
     results = []
     error = False
 
-    def process_query_collection(collection_name, query_embedding):
+    def process_query_collection(collection_name, query_embedding, k):
         try:
             if collection_name:
                 result = query_doc(
@@ -291,18 +293,30 @@ def query_collection(
             log.exception(f"Error when querying the collection: {e}")
             return None, e
 
-    # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
 
+    from open_webui.models.knowledge import Knowledges
     with ThreadPoolExecutor() as executor:
         future_results = []
-        for query_embedding in query_embeddings:
-            for collection_name in collection_names:
+        for collection_name in collection_names:
+            rag_config = {}
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                k = rag_config.get("TOP_K", k)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+            # Generate embeddings for each query using the collection's embedding function
+            query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+            for query_embedding in query_embeddings:
                 result = executor.submit(
-                    process_query_collection, collection_name, query_embedding
+                    process_query_collection, collection_name, query_embedding, k
                 )
                 future_results.append(result)
     task_results = [future.result() for future in future_results]
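Because the rewritten loop regenerates query embeddings once per collection, collections that share an embedding model recompute identical vectors. A sketch of memoizing per model name (hypothetical, not in this commit; `ef`, `user`, and `RAG_EMBEDDING_QUERY_PREFIX` come from the surrounding function):

```python
embeddings_cache: dict = {}

def embeddings_for(model, queries):
    # Cache per embedding model so collections sharing a model
    # do not recompute identical query vectors.
    if model not in embeddings_cache:
        embeddings_cache[model] = ef[model](
            queries, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
        )
    return embeddings_cache[model]
```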
@@ -322,12 +336,14 @@ def query_collection(
 def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
-    embedding_function,
+    user,
+    ef,
     k: int,
     reranking_function,
     k_reranker: int,
     r: float,
     hybrid_bm25_weight: float,
+    embedding_model: str,
 ) -> dict:
     results = []
     error = False
@@ -352,13 +368,32 @@ def query_collection_with_hybrid_search(
 
     def process_query(collection_name, query):
         try:
+            from open_webui.models.knowledge import Knowledges
+
+            # Use Knowledges to get the per-collection RAG config
+            knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
+
+            if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
+                rag_config = knowledge_base.rag_config
+                # Use config from rag_config if present, else fall back to the global config
+                embedding_model = rag_config.get("embedding_model", embedding_model)
+                reranking_model = rag_config.get("reranking_function", reranking_model)
+                k = rag_config.get("TOP_K", k)
+                k_reranker = rag_config.get("TOP_K_RERANKER", k_reranker)
+                r = rag_config.get("RELEVANCE_THRESHOLD", r)
+                hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", hybrid_bm25_weight)
+
+            embedding_function = lambda query, prefix: ef[embedding_model](
+                query, prefix=prefix, user=user
+            )
+
             result = query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,
                 embedding_function=embedding_function,
                 k=k,
-                reranking_function=reranking_function,
+                reranking_function=reranking_function[reranking_model],
                 k_reranker=k_reranker,
                 r=r,
                 hybrid_bm25_weight=hybrid_bm25_weight,
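One thing to watch in `process_query` above: `embedding_model`, `k`, and the other names are parameters of the enclosing function, and `reranking_model` is only bound inside the `if` branch. In Python, rebinding an outer name inside a nested function makes it local, so these assignments raise `UnboundLocalError` unless the values are captured explicitly. A minimal sketch of the safer shape, binding outer values as defaults (hypothetical, not part of this commit):

```python
# Hypothetical restructuring: defaults are evaluated once at def time,
# so the nested function gets its own locals instead of shadowing.
def process_query(collection_name, query, embedding_model=embedding_model,
                  reranking_model=None, k=k, k_reranker=k_reranker, r=r,
                  hybrid_bm25_weight=hybrid_bm25_weight):
    knowledge_base = Knowledges.get_knowledge_by_id(collection_name)
    if knowledge_base and not knowledge_base.rag_config.get("DEFAULT_RAG_SETTINGS", True):
        rag_config = knowledge_base.rag_config
        embedding_model = rag_config.get("embedding_model", embedding_model)
        reranking_model = rag_config.get("reranking_function", reranking_model)
        k = rag_config.get("TOP_K", k)
    ...
```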
@@ -446,7 +481,8 @@ def get_sources_from_files(
     request,
     files,
     queries,
-    embedding_function,
+    user,
+    ef,
     k,
     reranking_function,
     k_reranker,
@@ -454,9 +490,10 @@ def get_sources_from_files(
     hybrid_bm25_weight,
     hybrid_search,
     full_context=False,
+    embedding_model=None,
 ):
     log.debug(
-        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+        f"files: {files} {queries} {ef[embedding_model]} {reranking_function} {full_context}"
     )
 
     extracted_collections = []
@@ -564,12 +601,14 @@ def get_sources_from_files(
                     context = query_collection_with_hybrid_search(
                         collection_names=collection_names,
                         queries=queries,
-                        embedding_function=embedding_function,
+                        user=user,
+                        ef=ef,
                         k=k,
                         reranking_function=reranking_function,
                         k_reranker=k_reranker,
                         r=r,
                         hybrid_bm25_weight=hybrid_bm25_weight,
+                        embedding_model=embedding_model,
                     )
                 except Exception as e:
                     log.debug(
@@ -581,8 +620,10 @@ def get_sources_from_files(
                 context = query_collection(
                     collection_names=collection_names,
                     queries=queries,
-                    embedding_function=embedding_function,
+                    user=user,
+                    ef=ef,
                     k=k,
+                    embedding_model=embedding_model,
                 )
         except Exception as e:
             log.exception(e)
@@ -17,6 +17,7 @@ from fastapi import (
     UploadFile,
     status,
     Query,
+    Form,
 )
 from fastapi.responses import FileResponse, StreamingResponse
 from open_webui.constants import ERROR_MESSAGES
@@ -90,6 +91,7 @@ def upload_file(
     process: bool = Query(True),
     internal: bool = False,
     user=Depends(get_verified_user),
+    knowledge_id: Optional[str] = Form(None),
 ):
     log.info(f"file.content_type: {file.content_type}")
 
@@ -173,18 +175,18 @@ def upload_file(
 
                 process_file(
                     request,
-                    ProcessFileForm(file_id=id, content=result.get("text", "")),
+                    ProcessFileForm(file_id=id, content=result.get("text", ""), knowledge_id=knowledge_id),
                     user=user,
                 )
             elif (not file.content_type.startswith(("image/", "video/"))) or (
                 request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
             ):
-                process_file(request, ProcessFileForm(file_id=id), user=user)
+                process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
             else:
                 log.info(
                     f"File type {file.content_type} is not provided, but trying to process anyway"
                 )
-                process_file(request, ProcessFileForm(file_id=id), user=user)
+                process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
 
         file_item = Files.get_file_by_id(id=id)
     except Exception as e:
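With the router change above, clients can attach a knowledge base id as a multipart form field when uploading. A call sketch using `requests` (host, token, and id are placeholders):

```python
import requests

BASE = "http://localhost:8080/api/v1"          # hypothetical deployment URL
headers = {"Authorization": "Bearer <token>"}  # placeholder token

with open("notes.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE}/files/",
        headers=headers,
        files={"file": f},
        data={"knowledge_id": "kb-123"},  # made-up id; routes processing to that KB's RAG config
    )
resp.raise_for_status()
print(resp.json()["id"])
```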
@@ -241,6 +241,69 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
     )
     return True
 
+
+@router.post("/reindex/{id}", response_model=bool)
+async def reindex_specific_knowledge_files(request: Request, id: str, user=Depends(get_verified_user)):
+    log.info(f"reindex_specific_knowledge_files called with id={id}")
+    knowledge_base = Knowledges.get_knowledge_by_id(id=id)
+
+    deleted_knowledge_bases = []
+
+    # -- Robust error handling for missing or invalid data
+    if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
+        log.warning(
+            f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
+        )
+        try:
+            Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
+            deleted_knowledge_bases.append(knowledge_base.id)
+        except Exception as e:
+            log.error(
+                f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
+            )
+
+    try:
+        file_ids = knowledge_base.data.get("file_ids", [])
+        files = Files.get_files_by_ids(file_ids)
+        try:
+            if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
+                VECTOR_DB_CLIENT.delete_collection(
+                    collection_name=knowledge_base.id
+                )
+        except Exception as e:
+            log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
+
+        failed_files = []
+        for file in files:
+            try:
+                process_file(
+                    request,
+                    ProcessFileForm(
+                        file_id=file.id, collection_name=knowledge_base.id
+                    ),
+                    user=user,
+                )
+            except Exception as e:
+                log.error(
+                    f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
+                )
+                failed_files.append({"file_id": file.id, "error": str(e)})
+                continue
+
+        if failed_files:
+            log.warning(
+                f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
+            )
+            for failed in failed_files:
+                log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
+
+    except Exception as e:
+        log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
+
+    log.info(
+        f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
+    )
+    return True
+
+
 ############################
 # GetKnowledgeById
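A quick call sketch for the new per-knowledge-base reindex endpoint (URL, token, and id are placeholders):

```python
import requests

# Hypothetical values for illustration.
resp = requests.post(
    "http://localhost:8080/api/v1/knowledge/reindex/kb-123",
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
assert resp.json() is True  # the endpoint returns a bare boolean
```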
@@ -17,7 +17,7 @@ router = APIRouter()
 
 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL]("hello world")}
 
 
 ############################
@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                     memory.content, user=user
                 ),
                 "metadata": {"created_at": memory.created_at},
@@ -84,7 +84,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](form_data.content, user=user)],
         limit=form_data.k,
     )
 
@@ -107,7 +107,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                    memory.content, user=user
                ),
                "metadata": {
@@ -166,7 +166,7 @@ async def update_memory_by_id(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
+                "vector": request.app.state.EMBEDDING_FUNCTION[request.app.state.config.RAG_EMBEDDING_MODEL](
                     memory.content, user=user
                 ),
                 "metadata": {
File diff suppressed because it is too large
@@ -0,0 +1,338 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestRagConfig(AbstractPostgresTest):
+    BASE_PATH = "/api/v1/config"
+
+    def setup_class(cls):
+        super().setup_class()
+        from open_webui.models.knowledge import Knowledges
+
+        cls.knowledges = Knowledges
+
+    def setup_method(self):
+        super().setup_method()
+        # Insert a knowledge base with default settings
+        self.knowledges.insert_new_knowledge(
+            id="1",
+            name="Default KB",
+            rag_config={
+                "DEFAULT_RAG_SETTINGS": True,
+                "TEMPLATE": "default-template",
+                "TOP_K": 5,
+            },
+        )
+        # Insert a knowledge base with a custom RAG config
+        self.knowledges.insert_new_knowledge(
+            id="2",
+            name="Custom KB",
+            rag_config={
+                "DEFAULT_RAG_SETTINGS": False,
+                "TEMPLATE": "custom-template",
+                "TOP_K": 10,
+                "web": {
+                    "ENABLE_WEB_SEARCH": True,
+                    "WEB_SEARCH_ENGINE": "custom-engine",
+                },
+            },
+        )
+
+    def test_get_rag_config_default(self):
+        # Should return the default config for a knowledge base with DEFAULT_RAG_SETTINGS True
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url(""),
+                json={"knowledge_id": "1"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["RAG_TEMPLATE"] == "default-template"
+        assert data["TOP_K"] == 5
+        assert data["DEFAULT_RAG_SETTINGS"] is True
+
+    def test_get_rag_config_individual(self):
+        # Should return the custom config for a knowledge base with DEFAULT_RAG_SETTINGS False
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url(""),
+                json={"knowledge_id": "2"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["RAG_TEMPLATE"] == "custom-template"
+        assert data["TOP_K"] == 10
+        assert data["DEFAULT_RAG_SETTINGS"] is False
+        assert data["web"]["ENABLE_WEB_SEARCH"] is True
+        assert data["web"]["WEB_SEARCH_ENGINE"] == "custom-engine"
+
+    def test_get_rag_config_unauthorized(self):
+        # Should return 401 if not authenticated
+        response = self.fast_api_client.post(
+            self.create_url(""),
+            json={"knowledge_id": "1"},
+        )
+        assert response.status_code == 401
+
+    def test_update_rag_config_default(self):
+        # Should update the global config for a knowledge base with DEFAULT_RAG_SETTINGS True
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/update"),
+                json={
+                    "knowledge_id": "1",
+                    "RAG_TEMPLATE": "updated-template",
+                    "TOP_K": 42,
+                    "ENABLE_RAG_HYBRID_SEARCH": False,
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["RAG_TEMPLATE"] == "updated-template"
+        assert data["TOP_K"] == 42
+        assert data["ENABLE_RAG_HYBRID_SEARCH"] is False
+
+    def test_update_rag_config_individual(self):
+        # Should update the config for a knowledge base with DEFAULT_RAG_SETTINGS False
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/update"),
+                json={
+                    "knowledge_id": "2",
+                    "TEMPLATE": "individual-updated",
+                    "TOP_K": 99,
+                    "web": {
+                        "ENABLE_WEB_SEARCH": False,
+                        "WEB_SEARCH_ENGINE": "updated-engine",
+                    },
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["TEMPLATE"] == "individual-updated"
+        assert data["TOP_K"] == 99
+        assert data["web"]["ENABLE_WEB_SEARCH"] is False
+        assert data["web"]["WEB_SEARCH_ENGINE"] == "updated-engine"
+
+    def test_update_reranking_model_and_states_individual(self):
+        # Simulate app state for reranking models
+        app = self.fast_api_client.app
+        app.state.rf = {}
+        app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []}
+        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []}
+
+        # Update the individual config with a new reranking model
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/update"),
+                json={
+                    "knowledge_id": "2",
+                    "RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3",
+                    "RAG_RERANKING_ENGINE": "",
+                    "RAG_EXTERNAL_RERANKER_URL": "",
+                    "RAG_EXTERNAL_RERANKER_API_KEY": "",
+                    "ENABLE_RAG_HYBRID_SEARCH": True,
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        # The model should be in the loaded and downloaded models
+        loaded = app.state.config.LOADED_RERANKING_MODELS[""]
+        downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""]
+        assert any(m["RAG_RERANKING_MODEL"] == "BAAI/bge-reranker-v2-m3" for m in loaded)
+        assert "BAAI/bge-reranker-v2-m3" in downloaded
+        assert "BAAI/bge-reranker-v2-m3" in app.state.rf
+
+    def test_update_reranking_model_and_states_default(self):
+        # Simulate app state for reranking models
+        app = self.fast_api_client.app
+        app.state.rf = {}
+        app.state.config.LOADED_RERANKING_MODELS = {"": [], "external": []}
+        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": [], "external": []}
+
+        # Update the default config with a new reranking model
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/update"),
+                json={
+                    "knowledge_id": "1",
+                    "RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3",
+                    "RAG_RERANKING_ENGINE": "",
+                    "RAG_EXTERNAL_RERANKER_URL": "",
+                    "RAG_EXTERNAL_RERANKER_API_KEY": "",
+                    "ENABLE_RAG_HYBRID_SEARCH": True,
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        loaded = app.state.config.LOADED_RERANKING_MODELS[""]
+        downloaded = app.state.config.DOWNLOADED_RERANKING_MODELS[""]
+        assert any(m["RAG_RERANKING_MODEL"] == "BAAI/bge-reranker-v2-m3" for m in loaded)
+        assert "BAAI/bge-reranker-v2-m3" in downloaded
+        assert "BAAI/bge-reranker-v2-m3" in app.state.rf
+
+    def test_update_rag_config_unauthorized(self):
+        # Should return 401 if not authenticated
+        response = self.fast_api_client.post(
+            self.create_url("/update"),
+            json={"knowledge_id": "1", "RAG_TEMPLATE": "should-not-update"},
+        )
+        assert response.status_code == 401
+
+    def test_reranking_model_freed_only_if_not_in_use_elsewhere(self):
+        """
+        Test that the reranking model is only deleted from state if no other knowledge base is using it.
+        """
+        app = self.fast_api_client.app
+        app.state.rf = {"rerank-model-shared": object()}
+        app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3"}]}
+        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BAAI/bge-reranker-v2-m3"]}
+
+        # Patch is_model_in_use_elsewhere to simulate the model still being in use
+        from unittest.mock import patch
+
+        with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=True):
+            with mock_webui_user(id="1"):
+                response = self.fast_api_client.post(
+                    self.create_url("/update"),
+                    json={
+                        "knowledge_id": "2",
+                        "RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3",
+                        "RAG_RERANKING_ENGINE": "",
+                        "ENABLE_RAG_HYBRID_SEARCH": False,
+                    },
+                )
+            assert response.status_code == 200
+            # The model should NOT be deleted from state
+            assert "rerank-model-shared" in app.state.rf
+            assert any(m["RAG_RERANKING_MODEL"] == "BAAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""])
+
+        # Now simulate the model NOT being in use elsewhere
+        app.state.rf = {"rerank-model-shared": object()}
+        app.state.config.LOADED_RERANKING_MODELS = {"": [{"RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3"}]}
+        app.state.config.DOWNLOADED_RERANKING_MODELS = {"": ["BAAI/bge-reranker-v2-m3"]}
+
+        with patch("open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere", return_value=False):
+            with mock_webui_user(id="1"):
+                response = self.fast_api_client.post(
+                    self.create_url("/update"),
+                    json={
+                        "knowledge_id": "2",
+                        "RAG_RERANKING_MODEL": "BAAI/bge-reranker-v2-m3",
+                        "RAG_RERANKING_ENGINE": "",
+                        "ENABLE_RAG_HYBRID_SEARCH": False,
+                    },
+                )
+            assert response.status_code == 200
+            # The model should be deleted from state
+            assert "rerank-model-shared" not in app.state.rf
+            assert not any(m["RAG_RERANKING_MODEL"] == "BAAI/bge-reranker-v2-m3" for m in app.state.config.LOADED_RERANKING_MODELS[""])
+
+    def test_get_embedding_config_default(self):
+        # Should return the default embedding config for a knowledge base with DEFAULT_RAG_SETTINGS True
+        # First, add an embedding config to the default KB
+        self.knowledges.update_rag_config_by_id(
+            id="1",
+            rag_config={
+                "DEFAULT_RAG_SETTINGS": True,
+                "embedding_engine": "",
+                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+                "embedding_batch_size": 1,
+                "openai_config": {"url": "https://api.openai.com", "key": "default-key"},
+                "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"},
+            },
+        )
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/embedding"),
+                json={"knowledge_id": "1"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["embedding_engine"] == ""
+        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+        assert data["embedding_batch_size"] == 1
+        assert data["openai_config"]["url"] == "https://api.openai.com"
+        assert data["openai_config"]["key"] == "default-key"
+        assert data["ollama_config"]["url"] == "http://localhost:11434"
+        assert data["ollama_config"]["key"] == "ollama-key"
+
+    def test_get_embedding_config_individual(self):
+        # Should return the custom embedding config for a knowledge base with DEFAULT_RAG_SETTINGS False
+        self.knowledges.update_rag_config_by_id(
+            id="2",
+            rag_config={
+                "DEFAULT_RAG_SETTINGS": False,
+                "embedding_engine": "",
+                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+                "embedding_batch_size": 2,
+                "openai_config": {"url": "https://custom.openai.com", "key": "custom-key"},
+                "ollama_config": {"url": "http://custom-ollama:11434", "key": "custom-ollama-key"},
+            },
+        )
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/embedding"),
+                json={"knowledge_id": "2"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["embedding_engine"] == ""
+        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+        assert data["embedding_batch_size"] == 2
+        assert data["openai_config"]["url"] == "https://custom.openai.com"
+        assert data["openai_config"]["key"] == "custom-key"
+        assert data["ollama_config"]["url"] == "http://custom-ollama:11434"
+        assert data["ollama_config"]["key"] == "custom-ollama-key"
+
+    def test_update_embedding_config_default(self):
+        # Should update the global embedding config for a knowledge base with DEFAULT_RAG_SETTINGS True
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/embedding/update"),
+                json={
+                    "knowledge_id": "1",
+                    "embedding_engine": "",
+                    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+                    "embedding_batch_size": 4,
+                    "openai_config": {"url": "https://api.openai.com/v2", "key": "updated-key"},
+                    "ollama_config": {"url": "http://localhost:11434", "key": "ollama-key"},
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["embedding_engine"] == ""
+        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+        assert data["embedding_batch_size"] == 4
+        assert data["openai_config"]["url"] == "https://api.openai.com/v2"
+        assert data["openai_config"]["key"] == "updated-key"
+
+    def test_update_embedding_config_individual(self):
+        # Should update the embedding config for a knowledge base with DEFAULT_RAG_SETTINGS False
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/embedding/update"),
+                json={
+                    "knowledge_id": "2",
+                    "embedding_engine": "",
+                    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+                    "embedding_batch_size": 8,
+                    "openai_config": {"url": "https://custom.openai.com/v2", "key": "custom-key-2"},
+                    "ollama_config": {"url": "http://custom-ollama:11434/v2", "key": "custom-ollama-key-2"},
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["status"] is True
+        assert data["embedding_engine"] == ""
+        assert data["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+        assert data["embedding_batch_size"] == 8
+        assert data["openai_config"]["url"] == "https://custom.openai.com/v2"
+        assert data["openai_config"]["key"] == "custom-key-2"
+        assert data["ollama_config"]["url"] == "http://custom-ollama:11434/v2"
+        assert data["ollama_config"]["key"] == "custom-ollama-key-2"
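The in-use check in these tests is patched rather than exercised end to end; the same pattern works for any state-dependent helper. A compact sketch of the idiom (names taken from the tests above):

```python
from unittest.mock import patch

# Patch where the attribute is looked up, not where it is defined.
with patch(
    "open_webui.models.knowledge.Knowledges.is_model_in_use_elsewhere",
    return_value=False,
) as mocked:
    ...  # exercise the /update route
    assert mocked.called
```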
@@ -602,7 +602,7 @@ async def chat_image_generation_handler(
 
 
 async def chat_completion_files_handler(
-    request: Request, body: dict, user: UserModel
+    request: Request, body: dict, user: UserModel, model_knowledge
 ) -> tuple[dict, dict[str, list]]:
     sources = []
 
@@ -640,6 +640,21 @@ async def chat_completion_files_handler(
         queries = [get_last_user_message(body["messages"])]
 
     try:
+        # Check whether an individual RAG config is used
+        rag_config = {}
+        if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+            rag_config = model_knowledge[0].get("rag_config")
+
+        k = rag_config.get("TOP_K", request.app.state.config.TOP_K)
+        reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
+        reranking_function = request.app.state.rf[reranking_model] if reranking_model else None
+        k_reranker = rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
+        r = rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
+        hybrid_bm25_weight = rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT)
+        hybrid_search = rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
+        full_context = rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
+        embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
+
         # Offload get_sources_from_files to a separate thread
         loop = asyncio.get_running_loop()
         with ThreadPoolExecutor() as executor:
@@ -649,16 +664,16 @@ async def chat_completion_files_handler(
                     request=request,
                     files=files,
                     queries=queries,
-                    embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
-                        query, prefix=prefix, user=user
-                    ),
-                    k=request.app.state.config.TOP_K,
-                    reranking_function=request.app.state.rf,
-                    k_reranker=request.app.state.config.TOP_K_RERANKER,
-                    r=request.app.state.config.RELEVANCE_THRESHOLD,
-                    hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
-                    hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    full_context=request.app.state.config.RAG_FULL_CONTEXT,
                     user=user,
+                    ef=request.app.state.EMBEDDING_FUNCTION,
+                    k=k,
+                    reranking_function=reranking_function,
+                    k_reranker=k_reranker,
+                    r=r,
+                    hybrid_bm25_weight=hybrid_bm25_weight,
+                    hybrid_search=hybrid_search,
+                    full_context=full_context,
+                    embedding_model=embedding_model,
                 ),
             )
     except Exception as e:
@@ -917,7 +932,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             log.exception(e)
 
     try:
-        form_data, flags = await chat_completion_files_handler(request, form_data, user)
+        form_data, flags = await chat_completion_files_handler(request, form_data, user, model_knowledge)
         sources.extend(flags.get("sources", []))
     except Exception as e:
         log.exception(e)
@@ -958,20 +973,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     f"With a 0 relevancy threshold for RAG, the context cannot be empty"
                 )
 
+            # Adjusted RAG template step to use knowledge-base-specific configuration
+            rag_template_config = request.app.state.config.RAG_TEMPLATE
+
+            if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
+                rag_template_config = model_knowledge[0].get("rag_config").get(
+                    "RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE
+                )
+
             # Workaround for Ollama 2.0+ system prompt issue
             # TODO: replace with add_or_update_system_message
             if model.get("owned_by") == "ollama":
                 form_data["messages"] = prepend_to_first_user_message_content(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(rag_template_config, context_string, prompt),
                     form_data["messages"],
                 )
             else:
                 form_data["messages"] = add_or_update_system_message(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
+                    rag_template(rag_template_config, context_string, prompt),
                     form_data["messages"],
                 )
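The resolution order in the handler above is: use a key from the KB's rag_config when the KB has opted out of defaults, otherwise fall back to the global app config. A compact illustration with invented values:

```python
# Illustrative only: a KB that overrides TOP_K but nothing else.
rag_config = {"DEFAULT_RAG_SETTINGS": False, "TOP_K": 10}
app_top_k = 3           # stands in for request.app.state.config.TOP_K
app_threshold = 0.0     # stands in for ...config.RELEVANCE_THRESHOLD

k = rag_config.get("TOP_K", app_top_k)                    # -> 10 (overridden)
r = rag_config.get("RELEVANCE_THRESHOLD", app_threshold)  # -> 0.0 (falls back)
```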
@@ -1,12 +1,15 @@
 import { WEBUI_API_BASE_URL } from '$lib/constants';
 
-export const uploadFile = async (token: string, file: File, metadata?: object | null) => {
+export const uploadFile = async (token: string, file: File, metadata?: object | null, knowledge_id?: string) => {
 	const data = new FormData();
 	data.append('file', file);
 	if (metadata) {
 		data.append('metadata', JSON.stringify(metadata));
 	}
+	if (knowledge_id) {
+		data.append('knowledge_id', knowledge_id);
+	}
 
 	let error = null;
 
 	const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
@@ -4,7 +4,8 @@ export const createNewKnowledge = async (
 	token: string,
 	name: string,
 	description: string,
-	accessControl: null | object
+	accessControl: null | object,
+	rag_config: null | object
 ) => {
 	let error = null;
 
@@ -18,7 +19,8 @@ export const createNewKnowledge = async (
 		body: JSON.stringify({
 			name: name,
 			description: description,
-			access_control: accessControl
+			access_control: accessControl,
+			rag_config: rag_config
 		})
 	})
 		.then(async (res) => {
@@ -137,6 +139,7 @@ type KnowledgeUpdateForm = {
 	description?: string;
 	data?: object;
 	access_control?: null | object;
+	rag_config?: object;
 };
 
 export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => {
@@ -153,7 +156,8 @@ export const updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => {
 			name: form?.name ? form.name : undefined,
 			description: form?.description ? form.description : undefined,
 			data: form?.data ? form.data : undefined,
-			access_control: form.access_control
+			access_control: form.access_control,
+			rag_config: form?.rag_config ? form.rag_config : undefined
 		})
 	})
 		.then(async (res) => {
@@ -373,3 +377,31 @@ export const reindexKnowledgeFiles = async (token: string) => {
 
 	return res;
 };
+
+export const reindexSpecificKnowledgeFiles = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/reindex/${id}`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.error(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
@@ -1,14 +1,17 @@
 import { RETRIEVAL_API_BASE_URL } from '$lib/constants';
 
-export const getRAGConfig = async (token: string) => {
+export const getRAGConfig = async (token: string, collectionForm?: CollectionForm) => {
 	let error = null;
 
 	const res = await fetch(`${RETRIEVAL_API_BASE_URL}/config`, {
-		method: 'GET',
+		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify(collectionForm ? { collectionForm: collectionForm } : {})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -57,6 +60,7 @@ type RAGConfigForm = {
 	content_extraction?: ContentExtractConfigForm;
 	web_loader_ssl_verification?: boolean;
 	youtube?: YoutubeConfigForm;
+	knowledge_id?: string;
 };
 
 export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
@@ -152,15 +156,18 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings) => {
 	return res;
 };
 
-export const getEmbeddingConfig = async (token: string) => {
+export const getEmbeddingConfig = async (token: string, collectionForm?: CollectionForm) => {
 	let error = null;
 
 	const res = await fetch(`${RETRIEVAL_API_BASE_URL}/embedding`, {
-		method: 'GET',
+		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify(collectionForm ? { collectionForm: collectionForm } : {})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -196,6 +203,7 @@ type EmbeddingModelUpdateForm = {
 	embedding_engine: string;
 	embedding_model: string;
 	embedding_batch_size?: number;
+	knowledge_id?: string;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
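Since `/config` and `/embedding` are now POST endpoints that accept an optional knowledge base reference, a backend-side call looks like this sketch, following the payload shape exercised by the tests above (URL, token, and id are placeholders):

```python
import requests

# Hypothetical values for illustration.
resp = requests.post(
    "http://localhost:8080/api/v1/retrieval/config",
    headers={"Authorization": "Bearer <token>"},
    json={"knowledge_id": "kb-123"},  # omit to get the global defaults
)
config = resp.json()
print(config.get("DEFAULT_RAG_SETTINGS"), config.get("TOP_K"))
```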
File diff suppressed because it is too large
@@ -49,6 +49,7 @@
 	import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
 	import LockClosed from '$lib/components/icons/LockClosed.svelte';
 	import AccessControlModal from '../common/AccessControlModal.svelte';
+	import RagConfigModal from '../common/RagConfigModal.svelte';
 
 	let largeScreen = true;
 
@@ -73,7 +74,8 @@
 	let showAddTextContentModal = false;
 	let showSyncConfirmModal = false;
 	let showAccessControlModal = false;
+	let showRagConfigModal = false;
 
 	let inputFiles = null;
 
 	let filteredItems = [];
@@ -122,7 +124,7 @@
 		return file;
 	};
 
-	const uploadFileHandler = async (file) => {
+	const uploadFileHandler = async (file, knowledgeId) => {
 		console.log(file);
 
 		const tempItemId = uuidv4();
@@ -173,7 +175,7 @@
 			};
 		}
 
-		const uploadedFile = await uploadFile(localStorage.token, file, metadata).catch((e) => {
+		const uploadedFile = await uploadFile(localStorage.token, file, metadata, knowledgeId).catch((e) => {
 			toast.error(`${e}`);
 			return null;
 		});
@@ -264,7 +266,7 @@
 					const file = await entry.getFile();
 					const fileWithPath = new File([file], entryPath, { type: file.type });
 
-					await uploadFileHandler(fileWithPath);
+					await uploadFileHandler(fileWithPath, id);
 					uploadedFiles++;
 					updateProgress();
 				} else if (entry.kind === 'directory') {
@@ -326,7 +328,7 @@
 				const relativePath = file.webkitRelativePath || file.name;
 				const fileWithPath = new File([file], relativePath, { type: file.type });
 
-				await uploadFileHandler(fileWithPath);
+				await uploadFileHandler(fileWithPath, id);
 				uploadedFiles++;
 				updateProgress();
 			}
@@ -456,7 +458,8 @@
 			...knowledge,
 			name: knowledge.name,
 			description: knowledge.description,
-			access_control: knowledge.access_control
+			access_control: knowledge.access_control,
+			rag_config: knowledge.rag_config
 		}).catch((e) => {
 			toast.error(`${e}`);
 		});
@@ -524,7 +527,7 @@
 
 		if (inputFiles && inputFiles.length > 0) {
 			for (const file of inputFiles) {
-				await uploadFileHandler(file);
+				await uploadFileHandler(file, id);
 			}
 		} else {
 			toast.error($i18n.t(`File not found.`));
@@ -579,6 +582,18 @@
 
 		if (res) {
 			knowledge = res;
+			knowledge.rag_config.ALLOWED_FILE_EXTENSIONS = (config?.ALLOWED_FILE_EXTENSIONS ?? []).join(', ');
+
+			knowledge.rag_config.DOCLING_PICTURE_DESCRIPTION_LOCAL = JSON.stringify(
+				config.DOCLING_PICTURE_DESCRIPTION_LOCAL ?? {},
+				null,
+				2
+			);
+			knowledge.rag_config.DOCLING_PICTURE_DESCRIPTION_API = JSON.stringify(
+				config.DOCLING_PICTURE_DESCRIPTION_API ?? {},
+				null,
+				2
+			);
 		} else {
 			goto('/workspace/knowledge');
 		}
@@ -643,7 +658,7 @@
 	bind:show={showAddTextContentModal}
 	on:submit={(e) => {
 		const file = createFileFromText(e.detail.name, e.detail.content);
-		uploadFileHandler(file);
+		uploadFileHandler(file, id);
 	}}
 />
 
@@ -656,7 +671,7 @@
 	on:change={async () => {
 		if (inputFiles && inputFiles.length > 0) {
 			for (const file of inputFiles) {
-				await uploadFileHandler(file);
+				await uploadFileHandler(file, id);
 			}
 
 			inputFiles = null;
@@ -682,6 +697,16 @@
 	}}
 	accessRoles={['read', 'write']}
 />
+{#if knowledge.rag_config.DEFAULT_RAG_SETTINGS == false}
+	<RagConfigModal
+		bind:show={showRagConfigModal}
+		RAGConfig={knowledge.rag_config}
+		knowledgeId={knowledge.id}
+		on:update={(e) => {
+			knowledge.rag_config = e.detail; // sync updated config
+		}}
+	/>
+{/if}
 <div class="w-full mb-2.5">
 	<div class=" flex w-full">
 		<div class="flex-1">
@@ -713,6 +738,33 @@
 			</div>
 		</button>
 	</div>
+
+	{#if knowledge.rag_config.DEFAULT_RAG_SETTINGS == false}
+		<button
+			class="bg-gray-50 hover:bg-gray-100 text-black dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white transition px-2 py-1 rounded-full flex gap-1 items-center"
+			style="width: 150px;"
+			type="button"
+			on:click={() => {
+				showRagConfigModal = true;
+			}}
+		>
+			<svg
+				xmlns="http://www.w3.org/2000/svg"
+				viewBox="0 0 16 16"
+				fill="currentColor"
+				class="w-4 h-4"
+			>
+				<path
+					fill-rule="evenodd"
+					d="M6.955 1.45A.5.5 0 0 1 7.452 1h1.096a.5.5 0 0 1 .497.45l.17 1.699c.484.12.94.312 1.356.562l1.321-1.081a.5.5 0 0 1 .67.033l.774.775a.5.5 0 0 1 .034.67l-1.08 1.32c.25.417.44.873.561 1.357l1.699.17a.5.5 0 0 1 .45.497v1.096a.5.5 0 0 1-.45.497l-1.699.17c-.12.484-.312.94-.562 1.356l1.082 1.322a.5.5 0 0 1-.034.67l-.774.774a.5.5 0 0 1-.67.033l-1.322-1.08c-.416.25-.872.44-1.356.561l-.17 1.699a.5.5 0 0 1-.497.45H7.452a.5.5 0 0 1-.497-.45l-.17-1.699a4.973 4.973 0 0 1-1.356-.562L4.108 13.37a.5.5 0 0 1-.67-.033l-.774-.775a.5.5 0 0 1-.034-.67l1.08-1.32a4.971 4.971 0 0 1-.561-1.357l-1.699-.17A.5.5 0 0 1 1 8.548V7.452a.5.5 0 0 1 .45-.497l1.699-.17c.12-.484.312-.94.562-1.356L2.629 4.107a.5.5 0 0 1 .034-.67l.774-.774a.5.5 0 0 1 .67-.033L5.43 3.71a4.97 4.97 0 0 1 1.356-.561l.17-1.699ZM6 8c0 .538.212 1.026.558 1.385l.057.057a2 2 0 0 0 2.828-2.828l-.058-.056A2 2 0 0 0 6 8Z"
+					clip-rule="evenodd"
+				/>
+			</svg>
+			<div class="text-sm font-medium shrink-0">
+				{$i18n.t('RAG Config')}
+			</div>
+		</button>
+	{/if}
 </div>
 
 <div class="flex w-full px-1">
src/lib/components/workspace/common/RagConfigModal.svelte (new file, 1205 lines)
File diff suppressed because it is too large