mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into playwright
# Conflicts: # backend/open_webui/config.py # backend/open_webui/main.py # backend/open_webui/retrieval/web/utils.py # backend/open_webui/routers/retrieval.py # backend/open_webui/utils/middleware.py # pyproject.toml
This commit is contained in:
@@ -616,7 +616,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
models = request.app.state.MODELS
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
task_model_id = get_task_model_id(
|
||||
form_data["model"],
|
||||
request.app.state.config.TASK_MODEL,
|
||||
@@ -766,17 +772,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
|
||||
if "document" in source:
|
||||
for doc_idx, doc_context in enumerate(source["document"]):
|
||||
doc_metadata = source.get("metadata")
|
||||
doc_source_id = None
|
||||
|
||||
if doc_metadata:
|
||||
doc_source_id = doc_metadata[doc_idx].get("source", source_id)
|
||||
|
||||
if source_id:
|
||||
context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||
else:
|
||||
# If there is no source_id, then do not include the source_id tag
|
||||
context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
|
||||
context_string += f"<source><source_id>{doc_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
|
||||
|
||||
context_string = context_string.strip()
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
@@ -1149,6 +1145,46 @@ async def process_chat_response(
|
||||
|
||||
return content.strip()
|
||||
|
||||
def convert_content_blocks_to_messages(content_blocks):
|
||||
messages = []
|
||||
|
||||
temp_blocks = []
|
||||
for idx, block in enumerate(content_blocks):
|
||||
if block["type"] == "tool_calls":
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": serialize_content_blocks(temp_blocks),
|
||||
"tool_calls": block.get("content"),
|
||||
}
|
||||
)
|
||||
|
||||
results = block.get("results", [])
|
||||
|
||||
for result in results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": result["tool_call_id"],
|
||||
"content": result["content"],
|
||||
}
|
||||
)
|
||||
temp_blocks = []
|
||||
else:
|
||||
temp_blocks.append(block)
|
||||
|
||||
if temp_blocks:
|
||||
content = serialize_content_blocks(temp_blocks)
|
||||
if content:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def tag_content_handler(content_type, tags, content, content_blocks):
|
||||
end_flag = False
|
||||
|
||||
@@ -1540,7 +1576,6 @@ async def process_chat_response(
|
||||
|
||||
results = []
|
||||
for tool_call in response_tool_calls:
|
||||
print("\n\n" + str(tool_call) + "\n\n")
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("function", {}).get("name", "")
|
||||
|
||||
@@ -1606,23 +1641,10 @@ async def process_chat_response(
|
||||
{
|
||||
"model": model_id,
|
||||
"stream": True,
|
||||
"tools": form_data["tools"],
|
||||
"messages": [
|
||||
*form_data["messages"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": serialize_content_blocks(
|
||||
content_blocks, raw=True
|
||||
),
|
||||
"tool_calls": response_tool_calls,
|
||||
},
|
||||
*[
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": result["tool_call_id"],
|
||||
"content": result["content"],
|
||||
}
|
||||
for result in results
|
||||
],
|
||||
*convert_content_blocks_to_messages(content_blocks),
|
||||
],
|
||||
},
|
||||
user,
|
||||
@@ -1671,6 +1693,9 @@ async def process_chat_response(
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": code,
|
||||
"session_id": metadata.get(
|
||||
"session_id", None
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1699,10 +1724,12 @@ async def process_chat_response(
|
||||
"stdout": "Code interpreter engine not configured."
|
||||
}
|
||||
|
||||
log.debug(f"Code interpreter output: {output}")
|
||||
|
||||
if isinstance(output, dict):
|
||||
stdout = output.get("stdout", "")
|
||||
|
||||
if stdout:
|
||||
if isinstance(stdout, str):
|
||||
stdoutLines = stdout.split("\n")
|
||||
for idx, line in enumerate(stdoutLines):
|
||||
if "data:image/png;base64" in line:
|
||||
@@ -1734,7 +1761,7 @@ async def process_chat_response(
|
||||
|
||||
result = output.get("result", "")
|
||||
|
||||
if result:
|
||||
if isinstance(result, str):
|
||||
resultLines = result.split("\n")
|
||||
for idx, line in enumerate(resultLines):
|
||||
if "data:image/png;base64" in line:
|
||||
@@ -1784,6 +1811,8 @@ async def process_chat_response(
|
||||
}
|
||||
)
|
||||
|
||||
print(content_blocks, serialize_content_blocks(content_blocks))
|
||||
|
||||
try:
|
||||
res = await generate_chat_completion(
|
||||
request,
|
||||
|
||||
Reference in New Issue
Block a user