mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into bug/user-signup/fix-oauth-username-claim-has-no-effect
This commit is contained in:
@@ -44,6 +44,10 @@ from open_webui.utils.response import (
|
||||
convert_response_ollama_to_openai,
|
||||
convert_streaming_response_ollama_to_openai,
|
||||
)
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
|
||||
@@ -177,116 +181,38 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
|
||||
__event_call__ = get_event_call(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if not hasattr(function_module, "outlet"):
|
||||
continue
|
||||
try:
|
||||
outlet = function_module.outlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(outlet)
|
||||
params = {"body": data}
|
||||
|
||||
# Extra parameters to be passed to the function
|
||||
extra_params = {
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__request__": request,
|
||||
}
|
||||
|
||||
# Add extra params in contained in function signature
|
||||
for key, value in extra_params.items():
|
||||
if key in sig.parameters:
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(**params)
|
||||
else:
|
||||
data = outlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
return data
|
||||
try:
|
||||
result, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="outlet",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
|
||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||
|
||||
153
backend/open_webui/utils/code_interpreter.py
Normal file
153
backend/open_webui/utils/code_interpreter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import websockets
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
async def execute_code_jupyter(
|
||||
jupyter_url, code, token=None, password=None, timeout=10
|
||||
):
|
||||
"""
|
||||
Executes Python code in a Jupyter kernel.
|
||||
Supports authentication with a token or password.
|
||||
:param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
|
||||
:param code: Code to execute
|
||||
:param token: Jupyter authentication token (optional)
|
||||
:param password: Jupyter password (optional)
|
||||
:param timeout: WebSocket timeout in seconds (default: 10s)
|
||||
:return: Dictionary with stdout, stderr, and result
|
||||
"""
|
||||
session = requests.Session() # Maintain cookies
|
||||
headers = {} # Headers for requests
|
||||
|
||||
# Authenticate using password
|
||||
if password and not token:
|
||||
try:
|
||||
login_url = urljoin(jupyter_url, "/login")
|
||||
response = session.get(login_url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Retrieve `_xsrf` token
|
||||
xsrf_token = session.cookies.get("_xsrf")
|
||||
if not xsrf_token:
|
||||
raise ValueError("Failed to fetch _xsrf token")
|
||||
|
||||
# Send login request
|
||||
login_data = {"_xsrf": xsrf_token, "password": password}
|
||||
login_response = session.post(
|
||||
login_url, data=login_data, cookies=session.cookies
|
||||
)
|
||||
login_response.raise_for_status()
|
||||
|
||||
# Update headers with `_xsrf`
|
||||
headers["X-XSRFToken"] = xsrf_token
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Authentication Error: {str(e)}",
|
||||
"result": "",
|
||||
}
|
||||
|
||||
# Construct API URLs with authentication token if provided
|
||||
params = f"?token={token}" if token else ""
|
||||
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
|
||||
|
||||
try:
|
||||
# Include cookies if authenticating with password
|
||||
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
|
||||
response.raise_for_status()
|
||||
kernel_id = response.json()["id"]
|
||||
|
||||
# Construct WebSocket URL
|
||||
websocket_url = urljoin(
|
||||
jupyter_url.replace("http", "ws"),
|
||||
f"/api/kernels/{kernel_id}/channels{params}",
|
||||
)
|
||||
|
||||
# **IMPORTANT:** Include authentication cookies for WebSockets
|
||||
ws_headers = {}
|
||||
if password and not token:
|
||||
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
|
||||
cookies = {name: value for name, value in session.cookies.items()}
|
||||
ws_headers["Cookie"] = "; ".join(
|
||||
[f"{name}={value}" for name, value in cookies.items()]
|
||||
)
|
||||
|
||||
# Connect to the WebSocket
|
||||
async with websockets.connect(
|
||||
websocket_url, additional_headers=ws_headers
|
||||
) as ws:
|
||||
msg_id = str(uuid.uuid4())
|
||||
|
||||
# Send execution request
|
||||
execute_request = {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "user",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": "",
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"channel": "shell",
|
||||
}
|
||||
await ws.send(json.dumps(execute_request))
|
||||
|
||||
# Collect execution results
|
||||
stdout, stderr, result = "", "", None
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(ws.recv(), timeout)
|
||||
message_data = json.loads(message)
|
||||
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
|
||||
msg_type = message_data.get("msg_type")
|
||||
if msg_type == "stream":
|
||||
if message_data["content"]["name"] == "stdout":
|
||||
stdout += message_data["content"]["text"]
|
||||
elif message_data["content"]["name"] == "stderr":
|
||||
stderr += message_data["content"]["text"]
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
result = message_data["content"]["data"].get(
|
||||
"text/plain", ""
|
||||
)
|
||||
elif msg_type == "error":
|
||||
stderr += "\n".join(message_data["content"]["traceback"])
|
||||
elif (
|
||||
msg_type == "status"
|
||||
and message_data["content"]["execution_state"] == "idle"
|
||||
):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
stderr += "\nExecution timed out."
|
||||
break
|
||||
except Exception as e:
|
||||
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
|
||||
finally:
|
||||
# Shutdown the kernel
|
||||
if kernel_id:
|
||||
requests.delete(
|
||||
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
|
||||
)
|
||||
|
||||
return {
|
||||
"stdout": stdout.strip(),
|
||||
"stderr": stderr.strip(),
|
||||
"result": result.strip() if result else "",
|
||||
}
|
||||
|
||||
|
||||
# Example Usage
|
||||
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
|
||||
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))
|
||||
99
backend/open_webui/utils/filter.py
Normal file
99
backend/open_webui/utils/filter.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import inspect
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.models.functions import Functions
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
|
||||
async def process_filter_functions(
|
||||
request, filter_ids, filter_type, form_data, extra_params
|
||||
):
|
||||
skip_files = None
|
||||
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Prepare parameters
|
||||
sig = inspect.signature(handler)
|
||||
params = {"body": form_data} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in sig.parameters
|
||||
}
|
||||
|
||||
# Handle user parameters
|
||||
if "__user__" in sig.parameters:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# Execute handler
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
form_data = await handler(**params)
|
||||
else:
|
||||
form_data = handler(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||
raise e
|
||||
|
||||
# Handle file cleanup for inlet
|
||||
if skip_files and "files" in form_data.get("metadata", {}):
|
||||
del form_data["metadata"]["files"]
|
||||
|
||||
return form_data, {}
|
||||
@@ -161,7 +161,7 @@ async def comfyui_generate_image(
|
||||
seed = (
|
||||
payload.seed
|
||||
if payload.seed
|
||||
else random.randint(0, 18446744073709551614)
|
||||
else random.randint(0, 1125899906842624)
|
||||
)
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"][node.key] = seed
|
||||
|
||||
@@ -68,7 +68,11 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
||||
skip_files = None
|
||||
|
||||
def get_filter_function_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
filter_ids = get_filter_function_ids(model)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if hasattr(function_module, "inlet"):
|
||||
try:
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Create a dictionary of parameters to be passed to the function
|
||||
params = {"body": body} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in inspect.signature(inlet).parameters
|
||||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
body = await inlet(**params)
|
||||
else:
|
||||
body = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise e
|
||||
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
@@ -572,13 +483,13 @@ async def chat_image_generation_handler(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"description": f"An error occured while generating an image",
|
||||
"description": f"An error occurred while generating an image",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
|
||||
|
||||
if system_message_content:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
@@ -706,6 +617,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
@@ -782,8 +694,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
form_data, flags = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="inlet",
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
@@ -1122,6 +1038,20 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
def split_content_and_whitespace(content):
|
||||
content_stripped = content.rstrip()
|
||||
original_whitespace = (
|
||||
content[len(content_stripped) :]
|
||||
if len(content) > len(content_stripped)
|
||||
else ""
|
||||
)
|
||||
return content_stripped, original_whitespace
|
||||
|
||||
def is_opening_code_block(content):
|
||||
backtick_segments = content.split("```")
|
||||
# Even number of segments means the last backticks are opening a new block
|
||||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||||
|
||||
# Handle as a background task
|
||||
async def post_response_handler(response, events):
|
||||
def serialize_content_blocks(content_blocks, raw=False):
|
||||
@@ -1188,6 +1118,19 @@ async def process_chat_response(
|
||||
output = block.get("output", None)
|
||||
lang = attributes.get("lang", "")
|
||||
|
||||
content_stripped, original_whitespace = (
|
||||
split_content_and_whitespace(content)
|
||||
)
|
||||
if is_opening_code_block(content_stripped):
|
||||
# Remove trailing backticks that would open a new block
|
||||
content = (
|
||||
content_stripped.rstrip("`").rstrip()
|
||||
+ original_whitespace
|
||||
)
|
||||
else:
|
||||
# Keep content as is - either closing backticks or no backticks
|
||||
content = content_stripped + original_whitespace
|
||||
|
||||
if output:
|
||||
output = html.escape(json.dumps(output))
|
||||
|
||||
@@ -1242,10 +1185,10 @@ async def process_chat_response(
|
||||
match.end() :
|
||||
] # Content after opening tag
|
||||
|
||||
# Remove the start tag from the currently handling text block
|
||||
# Remove the start tag and after from the currently handling text block
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].replace(match.group(0), "")
|
||||
].replace(match.group(0) + after_tag, "")
|
||||
|
||||
if before_tag:
|
||||
content_blocks[-1]["content"] = before_tag
|
||||
@@ -1708,15 +1651,45 @@ async def process_chat_response(
|
||||
output = ""
|
||||
try:
|
||||
if content_blocks[-1]["attributes"].get("type") == "code":
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": content_blocks[-1]["content"],
|
||||
},
|
||||
code = content_blocks[-1]["content"]
|
||||
|
||||
if (
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE
|
||||
== "pyodide"
|
||||
):
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif (
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE
|
||||
== "jupyter"
|
||||
):
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
code,
|
||||
(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
== "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
== "password"
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
output = {
|
||||
"stdout": "Code interpreter engine not configured."
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(output, dict):
|
||||
stdout = output.get("stdout", "")
|
||||
|
||||
@@ -244,11 +244,12 @@ def get_gravatar_url(email):
|
||||
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
|
||||
|
||||
|
||||
def calculate_sha256(file):
|
||||
def calculate_sha256(file_path, chunk_size):
|
||||
# Compute SHA-256 hash of a file efficiently in chunks
|
||||
sha256 = hashlib.sha256()
|
||||
# Read the file in chunks to efficiently handle large files
|
||||
for chunk in iter(lambda: file.read(8192), b""):
|
||||
sha256.update(chunk)
|
||||
with open(file_path, "rb") as f:
|
||||
while chunk := f.read(chunk_size):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
@@ -40,7 +41,11 @@ from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.auth import get_password_hash, create_token
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
|
||||
|
||||
auth_manager_config = AppConfig()
|
||||
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
@@ -72,12 +77,15 @@ class OAuthManager:
|
||||
def get_user_role(self, user, user_data):
|
||||
if user and Users.get_num_users() == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
log.debug("Assigning the only user the admin role")
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
log.debug("Assigning the first user the admin role")
|
||||
return "admin"
|
||||
|
||||
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||
log.debug("Running OAUTH Role management")
|
||||
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
|
||||
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
|
||||
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
|
||||
@@ -93,17 +101,24 @@ class OAuthManager:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||
|
||||
log.debug(f"Oauth Roles claim: {oauth_claim}")
|
||||
log.debug(f"User roles from oauth: {oauth_roles}")
|
||||
log.debug(f"Accepted user roles: {oauth_allowed_roles}")
|
||||
log.debug(f"Accepted admin roles: {oauth_admin_roles}")
|
||||
|
||||
# If any roles are found, check if they match the allowed or admin roles
|
||||
if oauth_roles:
|
||||
# If role management is enabled, and matching roles are provided, use the roles
|
||||
for allowed_role in oauth_allowed_roles:
|
||||
# If the user has any of the allowed roles, assign the role "user"
|
||||
if allowed_role in oauth_roles:
|
||||
log.debug("Assigned user the user role")
|
||||
role = "user"
|
||||
break
|
||||
for admin_role in oauth_admin_roles:
|
||||
# If the user has any of the admin roles, assign the role "admin"
|
||||
if admin_role in oauth_roles:
|
||||
log.debug("Assigned user the admin role")
|
||||
role = "admin"
|
||||
break
|
||||
else:
|
||||
@@ -117,16 +132,27 @@ class OAuthManager:
|
||||
return role
|
||||
|
||||
def update_user_groups(self, user, user_data, default_permissions):
|
||||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||
log.debug(f"User oauth groups: {user_oauth_groups}")
|
||||
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
|
||||
log.debug(
|
||||
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
|
||||
)
|
||||
|
||||
# Remove groups that user is no longer a part of
|
||||
for group_model in user_current_groups:
|
||||
if group_model.name not in user_oauth_groups:
|
||||
# Remove group from user
|
||||
log.debug(
|
||||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
@@ -152,6 +178,9 @@ class OAuthManager:
|
||||
gm.name == group_model.name for gm in user_current_groups
|
||||
):
|
||||
# Add user to group
|
||||
log.debug(
|
||||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
@@ -193,7 +222,7 @@ class OAuthManager:
|
||||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token.get("userinfo")
|
||||
if not user_data:
|
||||
if not user_data or "email" not in user_data:
|
||||
user_data: UserInfo = await client.userinfo(token=token)
|
||||
if not user_data:
|
||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||
@@ -261,15 +290,20 @@ class OAuthManager:
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
if resp.ok:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(
|
||||
picture_url
|
||||
)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
else:
|
||||
picture_url = "/user.png"
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error downloading profile image '{picture_url}': {e}"
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
from html import escape
|
||||
|
||||
from markdown import markdown
|
||||
|
||||
@@ -41,13 +42,13 @@ class PDFGenerator:
|
||||
|
||||
def _build_html_message(self, message: Dict[str, Any]) -> str:
|
||||
"""Build HTML for a single message."""
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
role = escape(message.get("role", "user"))
|
||||
content = escape(message.get("content", ""))
|
||||
timestamp = message.get("timestamp")
|
||||
|
||||
model = message.get("model") if role == "assistant" else ""
|
||||
model = escape(message.get("model") if role == "assistant" else "")
|
||||
|
||||
date_str = self.format_timestamp(timestamp) if timestamp else ""
|
||||
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
|
||||
|
||||
# extends pymdownx extension to convert markdown to html.
|
||||
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
|
||||
@@ -76,6 +77,7 @@ class PDFGenerator:
|
||||
|
||||
def _generate_html_body(self) -> str:
|
||||
"""Generate the full HTML body for the PDF."""
|
||||
escaped_title = escape(self.form_data.title)
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
@@ -84,7 +86,7 @@ class PDFGenerator:
|
||||
<body>
|
||||
<div>
|
||||
<div>
|
||||
<h2>{self.form_data.title}</h2>
|
||||
<h2>{escaped_title}</h2>
|
||||
{self.messages_html}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user