Merge branch 'dev' into fix-db-order

This commit is contained in:
Timothy Jaeryang Baek
2025-03-26 20:55:42 -07:00
committed by GitHub
76 changed files with 1740 additions and 261 deletions

View File

@@ -1685,6 +1685,11 @@ BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
# Default retrieval result count ("k"); read from the RAG_TOP_K environment
# variable and parsed as an int, falling back to 3 when it is unset.
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
# Default result count for the reranking stage ("k_reranker"); read from the
# RAG_TOP_K_RERANKER environment variable and parsed as an int, falling back
# to 3 when it is unset.
RAG_TOP_K_RERANKER = PersistentConfig(
"RAG_TOP_K_RERANKER",
"rag.top_k_reranker",
int(os.environ.get("RAG_TOP_K_RERANKER", "3"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold",

View File

@@ -414,13 +414,12 @@ if OFFLINE_MODE:
####################################
# AUDIT LOGGING
####################################
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:

View File

@@ -223,6 +223,9 @@ async def generate_function_chat_completion(
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__chat_id__": metadata.get("chat_id", None),
"__session_id__": metadata.get("session_id", None),
"__message_id__": metadata.get("message_id", None),
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,

View File

@@ -191,6 +191,7 @@ from open_webui.config import (
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
RAG_TOP_K,
RAG_TOP_K_RERANKER,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES,
@@ -552,6 +553,7 @@ app.state.FUNCTIONS = {}
app.state.config.TOP_K = RAG_TOP_K
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

View File

@@ -105,7 +105,7 @@ class TikaLoader:
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]

View File

@@ -106,6 +106,7 @@ def query_doc_with_hybrid_search(
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
) -> dict:
try:
@@ -128,7 +129,7 @@ def query_doc_with_hybrid_search(
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
top_n=k_reranker,
reranking_function=reranking_function,
r_score=r,
)
@@ -138,10 +139,20 @@ def query_doc_with_hybrid_search(
)
result = compression_retriever.invoke(query)
# Split the reranked results into parallel lists, all in the order of `result`.
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]

# The compressor is configured with top_n=k_reranker, so it may return more
# than k items; when k < k_reranker keep only the k highest-scoring ones.
if k < k_reranker:
    ranked = sorted(
        zip(distances, documents, metadatas),
        key=lambda item: item[0],
        reverse=True,
    )[:k]
    # Unpack in exactly the order the tuples were zipped, otherwise documents
    # and metadatas end up swapped. zip(*[]) yields nothing to unpack, so the
    # empty case is handled explicitly instead of raising ValueError.
    if ranked:
        distances, documents, metadatas = map(list, zip(*ranked))
    else:
        distances, documents, metadatas = [], [], []
result = {
"distances": [[d.metadata.get("score") for d in result]],
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
"distances": [distances],
"documents": [documents],
"metadatas": [metadatas],
}
log.info(
@@ -264,6 +275,7 @@ def query_collection_with_hybrid_search(
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
) -> dict:
results = []
@@ -277,6 +289,7 @@ def query_collection_with_hybrid_search(
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
)
results.append(result)
@@ -290,10 +303,8 @@ def query_collection_with_hybrid_search(
raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
)
return merge_and_sort_query_results(results, k=k)
def get_embedding_function(
embedding_engine,
embedding_model,
@@ -337,6 +348,7 @@ def get_sources_from_files(
embedding_function,
k,
reranking_function,
k_reranker,
r,
hybrid_search,
full_context=False,
@@ -453,6 +465,7 @@ def get_sources_from_files(
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
)
except Exception as e:

View File

@@ -172,12 +172,19 @@ class ChromaClient:
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
try:
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
pass
def reset(self):
# Resets the database. This will delete all collections and item entries.

View File

@@ -2,6 +2,8 @@ import json
import logging
from typing import Optional
from open_webui.socket.main import get_event_emitter
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
@@ -372,6 +374,107 @@ async def update_chat_by_id(
)
############################
# UpdateChatMessageById
############################
class MessageForm(BaseModel):
    """Request body carrying the content to store for a chat message."""

    content: str
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
    """Store new content for one message of a chat and notify its clients.

    Args:
        id: Identifier of the chat that owns the message.
        message_id: Identifier of the message to update.
        form_data: Body holding the new message content.
        user: The verified user making the request.

    Returns:
        The updated chat as a ``ChatResponse``, or ``None`` when the upsert
        did not return a chat.

    Raises:
        HTTPException: 401 when the chat does not exist, or when the caller
            is neither the chat's owner nor an admin.
    """
    chat = Chats.get_chat_by_id(id)

    # A missing chat and a chat the caller may not touch are deliberately
    # reported with the same 401 response.
    if not chat or (chat.user_id != user.id and user.role != "admin"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    chat = Chats.upsert_message_to_chat_by_id_and_message_id(
        id,
        message_id,
        {
            "content": form_data.content,
        },
    )

    # The second argument (False) is passed positionally to the emitter
    # factory; the content has already been persisted by the upsert above.
    event_emitter = get_event_emitter(
        {
            "user_id": user.id,
            "chat_id": id,
            "message_id": message_id,
        },
        False,
    )

    if event_emitter:
        await event_emitter(
            {
                "type": "chat:message",
                "data": {
                    "chat_id": id,
                    "message_id": message_id,
                    "content": form_data.content,
                },
            }
        )

    # NOTE(review): the route is declared Optional[ChatResponse]; guard the
    # case where the upsert yields no chat instead of failing on
    # None.model_dump() — confirm the upsert can actually return None.
    if not chat:
        return None

    return ChatResponse(**chat.model_dump())
############################
# SendChatMessageEventById
############################
class EventForm(BaseModel):
    """Request body describing an event to emit for a chat message."""

    # Event type identifier, forwarded to the event emitter unchanged.
    type: str
    # Event payload, forwarded to the event emitter unchanged.
    data: dict
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
    id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
):
    """Emit a caller-supplied event for one message of a chat.

    Args:
        id: Identifier of the chat that owns the message.
        message_id: Identifier of the message the event refers to.
        form_data: The event (``type`` and ``data``) to emit.
        user: The verified user making the request.

    Returns:
        ``True`` when the event was emitted, ``False`` when no emitter is
        available or emitting failed.

    Raises:
        HTTPException: 401 when the chat does not exist, or when the caller
            is neither the chat's owner nor an admin.
    """
    chat = Chats.get_chat_by_id(id)

    # A missing chat and a chat the caller may not touch are deliberately
    # reported with the same 401 response.
    if not chat or (chat.user_id != user.id and user.role != "admin"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    event_emitter = get_event_emitter(
        {
            "user_id": user.id,
            "chat_id": id,
            "message_id": message_id,
        }
    )

    if not event_emitter:
        return False

    try:
        await event_emitter(form_data.model_dump())
    except Exception:
        # Catch Exception rather than using a bare except so task
        # cancellation and interpreter-exit signals still propagate, and
        # record the failure instead of hiding it.
        logging.getLogger(__name__).exception(
            "Failed to emit event for chat %s message %s", id, message_id
        )
        return False

    return True
############################
# DeleteChatById
############################

View File

@@ -719,6 +719,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
"status": True,
"template": request.app.state.config.RAG_TEMPLATE,
"k": request.app.state.config.TOP_K,
"k_reranker": request.app.state.config.TOP_K_RERANKER,
"r": request.app.state.config.RELEVANCE_THRESHOLD,
"hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -726,6 +727,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
class QuerySettingsForm(BaseModel):
k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None
template: Optional[str] = None
hybrid: Optional[bool] = None
@@ -737,6 +739,7 @@ async def update_query_settings(
):
request.app.state.config.RAG_TEMPLATE = form_data.template
request.app.state.config.TOP_K = form_data.k if form_data.k else 4
request.app.state.config.TOP_K_RERANKER = form_data.k_reranker or 4
request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
@@ -747,6 +750,7 @@ async def update_query_settings(
"status": True,
"template": request.app.state.config.RAG_TEMPLATE,
"k": request.app.state.config.TOP_K,
"k_reranker": request.app.state.config.TOP_K_RERANKER,
"r": request.app.state.config.RELEVANCE_THRESHOLD,
"hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -1495,6 +1499,7 @@ class QueryDocForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
@@ -1515,6 +1520,7 @@ def query_doc_handler(
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
r=(
form_data.r
if form_data.r
@@ -1543,6 +1549,7 @@ class QueryCollectionsForm(BaseModel):
collection_names: list[str]
query: str
k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
@@ -1563,6 +1570,7 @@ def query_collection_handler(
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
r=(
form_data.r
if form_data.r

View File

@@ -269,11 +269,19 @@ async def disconnect(sid):
# print(f"Unknown session ID {sid} disconnected")
def get_event_emitter(request_info):
def get_event_emitter(request_info, update_db=True):
async def __event_emitter__(event_data):
user_id = request_info["user_id"]
session_ids = list(
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
set(
USER_POOL.get(user_id, [])
+ (
[request_info.get("session_id")]
if request_info.get("session_id")
else []
)
)
)
for session_id in session_ids:
@@ -287,40 +295,41 @@ def get_event_emitter(request_info):
to=session_id,
)
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if update_db:
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
return __event_emitter__

View File

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})
try:
spec = tools[tool_function_name].get("spec", {})
tool = tools[tool_function_name]
spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)
if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)
except Exception as e:
tool_output = str(e)
@@ -565,6 +584,7 @@ async def chat_completion_files_handler(
),
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT,
@@ -764,12 +784,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")
tools_dict = {}
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@@ -780,20 +806,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")
if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, attach the tools to the metadata and pass their specs in the request instead of calling the tools handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
@@ -812,7 +848,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for source_idx, source in enumerate(sources):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string += f"<source><source_id>{source_idx + 1}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -1079,8 +1115,6 @@ async def process_chat_response(
for filter_id in get_sorted_filter_ids(model)
]
print(f"{filter_functions=}")
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1560,7 +1594,9 @@ async def process_chat_response(
value = delta.get("content")
reasoning_content = delta.get("reasoning_content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
if reasoning_content:
if (
not content_blocks
@@ -1774,9 +1810,25 @@ async def process_chat_response(
for k, v in tool_function_params.items()
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)
if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)

View File

@@ -1,6 +1,9 @@
import inspect
import logging
import re
import inspect
import uuid
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial