Mirror of https://github.com/open-webui/open-webui
Merge remote-tracking branch 'upstream/dev' into playwright

# Conflicts:
#	backend/open_webui/config.py
#	backend/open_webui/main.py
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
#	backend/open_webui/utils/middleware.py
#	pyproject.toml
@@ -7,14 +7,17 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio

-from fastapi import Request
-from starlette.responses import Response, StreamingResponse
+from fastapi import Request, status
+from starlette.responses import Response, StreamingResponse, JSONResponse


+from open_webui.models.users import UserModel

 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )

@@ -57,16 +60,127 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])


+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        async def message_listener(sid, data):
+            """
+            Handle received socket messages and push them into the queue.
+            """
+            await q.put(data)
+
+        # Register the listener
+        sio.on(channel, message_listener)
+
+        # Start processing chat completion in background
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print("res", res)
+
+        if res.get("status", False):
+            # Define a generator to stream responses
+            async def event_generator():
+                nonlocal q
+                try:
+                    while True:
+                        data = await q.get()  # Wait for new messages
+                        if isinstance(data, dict):
+                            if "done" in data and data["done"]:
+                                break  # Stop streaming when 'done' is received
+
+                            yield f"data: {json.dumps(data)}\n\n"
+                        elif isinstance(data, str):
+                            yield data
+                except Exception as e:
+                    log.debug(f"Error in event generator: {e}")
+                    pass
+
+            # Define a background task to run the event generator
+            async def background():
+                try:
+                    del sio.handlers["/"][channel]
+                except Exception as e:
+                    pass
+
+            # Return the streaming response
+            return StreamingResponse(
+                event_generator(), media_type="text/event-stream", background=background
+            )
+        else:
+            raise Exception(str(res))
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
     user: Any,
     bypass_filter: bool = False,
 ):
     log.debug(f"generate_chat_completion: {form_data}")
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

-    models = request.app.state.MODELS
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+        log.debug(f"direct connection to model: {models}")
+    else:
+        models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
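
For context: generate_direct_chat_completion never contacts a model itself.
It asks an already-connected browser session, via get_event_call, to run the
completion, and then relays whatever that session emits on the per-request
socket channel. A minimal sketch of the client half of that contract,
assuming a python-socketio AsyncClient; the transport event name and the
run_completion() helper are stand-ins, not Open WebUI API:

    import asyncio
    import socketio

    sio_client = socketio.AsyncClient()


    async def run_completion(form_data):
        # Hypothetical stand-in for the client-side model call.
        yield {"choices": [{"delta": {"content": "Hello"}}]}


    @sio_client.on("chat-events")  # assumed transport event name
    async def handle_event(payload):
        if payload.get("type") != "request:chat:completion":
            return
        data = payload["data"]
        channel = data["channel"]  # "{user_id}:{session_id}:{request_id}"

        async def respond():
            async for chunk in run_completion(data["form_data"]):
                await sio_client.emit(channel, chunk)  # lands in the server-side queue
            await sio_client.emit(channel, {"done": True})  # stops event_generator

        asyncio.create_task(respond())
        # Returned as the event-call acknowledgement; the server checks
        # res.get("status") before it starts streaming.
        return {"status": True}
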
@@ -80,85 +194,96 @@ async def generate_chat_completion(

     model = models[model_id]

-    # Check if user has access to the model
-    if not bypass_filter and user.role == "user":
-        try:
-            check_model_access(user, model)
-        except Exception as e:
-            raise e
-
-    if model["owned_by"] == "arena":
-        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-        if model_ids and filter_mode == "exclude":
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena" and model["id"] not in model_ids
-            ]
-
-        selected_model_id = None
-        if isinstance(model_ids, list) and model_ids:
-            selected_model_id = random.choice(model_ids)
-        else:
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena"
-            ]
-            selected_model_id = random.choice(model_ids)
-
-        form_data["model"] = selected_model_id
-
-        if form_data.get("stream") == True:
-
-            async def stream_wrapper(stream):
-                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
-                async for chunk in stream:
-                    yield chunk
-
-            response = await generate_chat_completion(
-                request, form_data, user, bypass_filter=True
-            )
-            return StreamingResponse(
-                stream_wrapper(response.body_iterator),
-                media_type="text/event-stream",
-                background=response.background,
-            )
-        else:
-            return {
-                **(
-                    await generate_chat_completion(
-                        request, form_data, user, bypass_filter=True
-                    )
-                ),
-                "selected_model_id": selected_model_id,
-            }
-
-    if model.get("pipe"):
-        # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
-        return await generate_function_chat_completion(
-            request, form_data, user=user, models=models
-        )
-    if model["owned_by"] == "ollama":
-        # Using /ollama/api/chat endpoint
-        form_data = convert_payload_openai_to_ollama(form_data)
-        response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
-        if form_data.get("stream"):
-            response.headers["content-type"] = "text/event-stream"
-            return StreamingResponse(
-                convert_streaming_response_ollama_to_openai(response),
-                headers=dict(response.headers),
-                background=response.background,
-            )
-        else:
-            return convert_response_ollama_to_openai(response)
-    else:
-        return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
+    if getattr(request.state, "direct", False):
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
+    else:
+        # Check if user has access to the model
+        if not bypass_filter and user.role == "user":
+            try:
+                check_model_access(user, model)
+            except Exception as e:
+                raise e
+
+        if model["owned_by"] == "arena":
+            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+            if model_ids and filter_mode == "exclude":
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
+                ]
+
+            selected_model_id = None
+            if isinstance(model_ids, list) and model_ids:
+                selected_model_id = random.choice(model_ids)
+            else:
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena"
+                ]
+                selected_model_id = random.choice(model_ids)
+
+            form_data["model"] = selected_model_id
+
+            if form_data.get("stream") == True:
+
+                async def stream_wrapper(stream):
+                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                    async for chunk in stream:
+                        yield chunk
+
+                response = await generate_chat_completion(
+                    request, form_data, user, bypass_filter=True
+                )
+                return StreamingResponse(
+                    stream_wrapper(response.body_iterator),
+                    media_type="text/event-stream",
+                    background=response.background,
+                )
+            else:
+                return {
+                    **(
+                        await generate_chat_completion(
+                            request, form_data, user, bypass_filter=True
+                        )
+                    ),
+                    "selected_model_id": selected_model_id,
+                }
+
+        if model.get("pipe"):
+            # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
+            return await generate_function_chat_completion(
+                request, form_data, user=user, models=models
+            )
+        if model["owned_by"] == "ollama":
+            # Using /ollama/api/chat endpoint
+            form_data = convert_payload_openai_to_ollama(form_data)
+            response = await generate_ollama_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
+            )
+            if form_data.get("stream"):
+                response.headers["content-type"] = "text/event-stream"
+                return StreamingResponse(
+                    convert_streaming_response_ollama_to_openai(response),
+                    headers=dict(response.headers),
+                    background=response.background,
+                )
+            else:
+                return convert_response_ollama_to_openai(response)
+        else:
+            return await generate_openai_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
+            )


 chat_completion = generate_chat_completion
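
The arena branch above picks a concrete model at random and then re-enters
generate_chat_completion with bypass_filter=True, tagging the result with
selected_model_id. A runnable sketch of just the selection rules, with
made-up model ids:

    import random

    MODELS = {
        "gpt-4o": {"id": "gpt-4o", "owned_by": "openai"},
        "llama3": {"id": "llama3", "owned_by": "ollama"},
        "arena": {"id": "arena", "owned_by": "arena"},
    }


    def pick_arena_model(model_ids, filter_mode):
        if model_ids and filter_mode == "exclude":
            # Keep every non-arena model that is not on the exclusion list.
            model_ids = [
                m["id"]
                for m in MODELS.values()
                if m.get("owned_by") != "arena" and m["id"] not in model_ids
            ]
        if not (isinstance(model_ids, list) and model_ids):
            # No explicit list: draw from all non-arena models.
            model_ids = [
                m["id"] for m in MODELS.values() if m.get("owned_by") != "arena"
            ]
        return random.choice(model_ids)


    print(pick_arena_model(["gpt-4o"], "exclude"))  # -> "llama3"
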
@@ -167,7 +292,13 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     data = form_data
     model_id = data["model"]

@@ -227,7 +358,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):

     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     data = form_data
     model_id = data["model"]
@@ -616,7 +616,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):

     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
-    models = request.app.state.MODELS
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS

     task_model_id = get_task_model_id(
         form_data["model"],
         request.app.state.config.TASK_MODEL,
@@ -766,17 +772,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):

             if "document" in source:
                 for doc_idx, doc_context in enumerate(source["document"]):
-                    doc_metadata = source.get("metadata")
-                    doc_source_id = None
-
-                    if doc_metadata:
-                        doc_source_id = doc_metadata[doc_idx].get("source", source_id)
-
-                    if source_id:
-                        context_string += f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
-                    else:
-                        # If there is no source_id, then do not include the source_id tag
-                        context_string += f"<source><source_context>{doc_context}</source_context></source>\n"
+                    context_string += f"<source><source_id>{doc_idx}</source_id><source_context>{doc_context}</source_context></source>\n"

         context_string = context_string.strip()
         prompt = get_last_user_message(form_data["messages"])
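
The citation hunk above replaces the per-document source-id lookup with a
plain enumeration: every retrieved chunk is now tagged with its index. A
small sketch of the resulting prompt context, with made-up chunks:

    sources = [
        {"document": ["Paris is the capital of France.", "It hosts the Louvre."]}
    ]

    context_string = ""
    for source in sources:
        if "document" in source:
            for doc_idx, doc_context in enumerate(source["document"]):
                context_string += (
                    f"<source><source_id>{doc_idx}</source_id>"
                    f"<source_context>{doc_context}</source_context></source>\n"
                )

    print(context_string.strip())
    # <source><source_id>0</source_id><source_context>Paris is the capital of France.</source_context></source>
    # <source><source_id>1</source_id><source_context>It hosts the Louvre.</source_context></source>
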
@@ -1149,6 +1145,46 @@ async def process_chat_response(

         return content.strip()

+    def convert_content_blocks_to_messages(content_blocks):
+        messages = []
+
+        temp_blocks = []
+        for idx, block in enumerate(content_blocks):
+            if block["type"] == "tool_calls":
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": serialize_content_blocks(temp_blocks),
+                        "tool_calls": block.get("content"),
+                    }
+                )
+
+                results = block.get("results", [])
+
+                for result in results:
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": result["tool_call_id"],
+                            "content": result["content"],
+                        }
+                    )
+                temp_blocks = []
+            else:
+                temp_blocks.append(block)
+
+        if temp_blocks:
+            content = serialize_content_blocks(temp_blocks)
+            if content:
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": content,
+                    }
+                )
+
+        return messages
+
     def tag_content_handler(content_type, tags, content, content_blocks):
         end_flag = False

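
A worked example of what convert_content_blocks_to_messages produces. The
real serialize_content_blocks in middleware.py is more involved; a trivial
join is assumed here so the example is self-contained:

    def serialize_content_blocks(blocks, raw=False):
        # Simplified stand-in: concatenate the text content of the blocks.
        return "".join(b.get("content", "") for b in blocks)

    content_blocks = [
        {"type": "text", "content": "Let me check the weather."},
        {
            "type": "tool_calls",
            "content": [{"id": "call_1", "function": {"name": "get_weather"}}],
            "results": [{"tool_call_id": "call_1", "content": '{"temp_c": 21}'}],
        },
        {"type": "text", "content": "It is 21 °C."},
    ]

    # convert_content_blocks_to_messages(content_blocks) then yields:
    # [
    #     {"role": "assistant", "content": "Let me check the weather.",
    #      "tool_calls": [{"id": "call_1", "function": {"name": "get_weather"}}]},
    #     {"role": "tool", "tool_call_id": "call_1", "content": '{"temp_c": 21}'},
    #     {"role": "assistant", "content": "It is 21 °C."},
    # ]
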
@@ -1540,7 +1576,6 @@ async def process_chat_response(

                     results = []
                     for tool_call in response_tool_calls:
-                        print("\n\n" + str(tool_call) + "\n\n")
                         tool_call_id = tool_call.get("id", "")
                         tool_name = tool_call.get("function", {}).get("name", "")

@@ -1606,23 +1641,10 @@ async def process_chat_response(
                         {
                             "model": model_id,
                             "stream": True,
                             "tools": form_data["tools"],
                             "messages": [
                                 *form_data["messages"],
-                                {
-                                    "role": "assistant",
-                                    "content": serialize_content_blocks(
-                                        content_blocks, raw=True
-                                    ),
-                                    "tool_calls": response_tool_calls,
-                                },
-                                *[
-                                    {
-                                        "role": "tool",
-                                        "tool_call_id": result["tool_call_id"],
-                                        "content": result["content"],
-                                    }
-                                    for result in results
-                                ],
+                                *convert_content_blocks_to_messages(content_blocks),
                             ],
                         },
                         user,
@@ -1671,6 +1693,9 @@ async def process_chat_response(
                                 "data": {
                                     "id": str(uuid4()),
                                     "code": code,
+                                    "session_id": metadata.get(
+                                        "session_id", None
+                                    ),
                                 },
                             }
                         )
@@ -1699,10 +1724,12 @@ async def process_chat_response(
                                     "stdout": "Code interpreter engine not configured."
                                 }

+                            log.debug(f"Code interpreter output: {output}")
+
                             if isinstance(output, dict):
                                 stdout = output.get("stdout", "")

                                 if stdout:
                                     if isinstance(stdout, str):
                                         stdoutLines = stdout.split("\n")
                                         for idx, line in enumerate(stdoutLines):
                                             if "data:image/png;base64" in line:
@@ -1734,7 +1761,7 @@ async def process_chat_response(

                             result = output.get("result", "")

                             if result:
                                 if isinstance(result, str):
                                     resultLines = result.split("\n")
                                     for idx, line in enumerate(resultLines):
                                         if "data:image/png;base64" in line:
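
Both the stdout and result branches above scan the interpreter output line
by line for inline PNG data URLs. A minimal sketch of that scan, with a
made-up output dict:

    output = {"stdout": "step 1 done\ndata:image/png;base64,iVBORw0KGgo=\n"}

    stdout = output.get("stdout", "")
    if isinstance(stdout, str):
        for idx, line in enumerate(stdout.split("\n")):
            if "data:image/png;base64" in line:
                # The real handler swaps the line out for an image attachment;
                # here we only report where the embedded image sits.
                print(f"embedded PNG at stdout line {idx}")
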
@@ -1784,6 +1811,8 @@ async def process_chat_response(
                                 }
                             )

+                        print(content_blocks, serialize_content_blocks(content_blocks))
+
                         try:
                             res = await generate_chat_completion(
                                 request,

@@ -217,12 +217,20 @@ def openai_chat_chunk_message_template(


 def openai_chat_completion_message_template(
-    model: str, message: Optional[str] = None, usage: Optional[dict] = None
+    model: str,
+    message: Optional[str] = None,
+    tool_calls: Optional[list[dict]] = None,
+    usage: Optional[dict] = None,
 ) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion"
     if message is not None:
-        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+        template["choices"][0]["message"] = {
+            "content": message,
+            "role": "assistant",
+            **({"tool_calls": tool_calls} if tool_calls else {}),
+        }

     template["choices"][0]["finish_reason"] = "stop"

     if usage:
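
Usage sketch for the extended template, assuming the helpers in this file
are importable; only the "message" choice is spelled out, the remaining
fields come from openai_chat_message_template:

    tool_calls = [
        {
            "index": 0,
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ]

    response = openai_chat_completion_message_template(
        "llama3", "Checking the weather...", tool_calls
    )
    # response["choices"][0]["message"] ==
    # {"content": "Checking the weather...", "role": "assistant",
    #  "tool_calls": [...]}
    # When tool_calls is falsy, the "tool_calls" key is omitted entirely.
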
@@ -6,9 +6,32 @@ from open_webui.utils.misc import (
 )


+def convert_ollama_tool_call_to_openai(tool_calls: list[dict]) -> list[dict]:
+    openai_tool_calls = []
+    for tool_call in tool_calls:
+        openai_tool_call = {
+            "index": tool_call.get("index", 0),
+            "id": tool_call.get("id", f"call_{str(uuid4())}"),
+            "type": "function",
+            "function": {
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": json.dumps(
+                    tool_call.get("function", {}).get("arguments", {})
+                ),
+            },
+        }
+        openai_tool_calls.append(openai_tool_call)
+    return openai_tool_calls
+
+
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
     message_content = ollama_response.get("message", {}).get("content", "")
+    tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
+    openai_tool_calls = None
+
+    if tool_calls:
+        openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)

     data = ollama_response
     usage = {

@@ -51,7 +74,9 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
         ),
     }

-    response = openai_chat_completion_message_template(model, message_content, usage)
+    response = openai_chat_completion_message_template(
+        model, message_content, openai_tool_calls, usage
+    )

     return response
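
Example input and output for the new helper; note that Ollama returns tool
arguments as a dict, while the OpenAI shape expects a JSON string, hence the
json.dumps:

    ollama_tool_calls = [
        {"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
    ]

    print(convert_ollama_tool_call_to_openai(ollama_tool_calls))
    # [{"index": 0, "id": "call_<uuid>", "type": "function",
    #   "function": {"name": "get_weather",
    #                "arguments": "{\"city\": \"Paris\"}"}}]
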
@@ -65,20 +90,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
         openai_tool_calls = None

         if tool_calls:
-            openai_tool_calls = []
-            for tool_call in tool_calls:
-                openai_tool_call = {
-                    "index": tool_call.get("index", 0),
-                    "id": tool_call.get("id", f"call_{str(uuid4())}"),
-                    "type": "function",
-                    "function": {
-                        "name": tool_call.get("function", {}).get("name", ""),
-                        "arguments": json.dumps(
-                            tool_call.get("function", {}).get("arguments", {})
-                        ),
-                    },
-                }
-                openai_tool_calls.append(openai_tool_call)
+            openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)

         done = data.get("done", False)