Merge branch 'open-webui:main' into fix/oidc-500-error-name-field

This commit is contained in:
Kevin Wang
2025-01-27 13:10:08 +01:00
committed by GitHub
193 changed files with 7697 additions and 3905 deletions

View File

@@ -1,9 +1,30 @@
from typing import Optional, Union, List, Dict, Any
from open_webui.models.users import Users, UserModel
from open_webui.models.groups import Groups
from open_webui.config import DEFAULT_USER_PERMISSIONS
import json
def fill_missing_permissions(
permissions: Dict[str, Any], default_permissions: Dict[str, Any]
) -> Dict[str, Any]:
"""
Recursively fills in missing properties in the permissions dictionary
using the default permissions as a template.
"""
for key, value in default_permissions.items():
if key not in permissions:
permissions[key] = value
elif isinstance(value, dict) and isinstance(
permissions[key], dict
): # Both are nested dictionaries
permissions[key] = fill_missing_permissions(permissions[key], value)
return permissions
def get_permissions(
user_id: str,
default_permissions: Dict[str, Any],
@@ -27,39 +48,45 @@ def get_permissions(
if key not in permissions:
permissions[key] = value
else:
permissions[key] = permissions[key] or value
permissions[key] = (
permissions[key] or value
) # Use the most permissive value (True > False)
return permissions
user_groups = Groups.get_groups_by_member_id(user_id)
# deep copy default permissions to avoid modifying the original dict
# Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions))
# Combine permissions from all user groups
for group in user_groups:
group_permissions = group.permissions
permissions = combine_permissions(permissions, group_permissions)
# Ensure all fields from default_permissions are present and filled in
permissions = fill_missing_permissions(permissions, default_permissions)
return permissions
def has_permission(
user_id: str,
permission_key: str,
default_permissions: Dict[str, bool] = {},
default_permissions: Dict[str, Any] = {},
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
and falls back to default permissions if not found in any group.
and fall back to default permissions if not found in any group.
Permission keys can be hierarchical and separated by dots ('.').
"""
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool:
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
for key in keys:
if key not in permissions:
return False # If any part of the hierarchy is missing, deny access
permissions = permissions[key] # Go one level deeper
permissions = permissions[key] # Traverse one level deeper
return bool(permissions) # Return the boolean at the final level
@@ -73,7 +100,10 @@ def has_permission(
if get_permission(group_permissions, permission_hierarchy):
return True
# Check default permissions afterwards if the group permissions don't allow it
# Check default permissions afterward if the group permissions don't allow it
default_permissions = fill_missing_permissions(
default_permissions, DEFAULT_USER_PERMISSIONS
)
return get_permission(default_permissions, permission_hierarchy)

View File

@@ -28,9 +28,13 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.utils.webhook import post_webhook
@@ -486,6 +490,100 @@ async def chat_web_search_handler(
return form_data
async def chat_image_generation_handler(
request: Request, form_data: dict, extra_params: dict, user
):
__event_emitter__ = extra_params["__event_emitter__"]
await __event_emitter__(
{
"type": "status",
"data": {"description": "Generating an image", "done": False},
}
)
messages = form_data["messages"]
user_message = get_last_user_message(messages)
prompt = user_message
negative_prompt = ""
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
try:
res = await generate_image_prompt(
request,
{
"model": form_data["model"],
"messages": messages,
},
user,
)
response = res["choices"][0]["message"]["content"]
try:
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1
if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")
response = response[bracket_start:bracket_end]
response = json.loads(response)
prompt = response.get("prompt", [])
except Exception as e:
prompt = user_message
except Exception as e:
log.exception(e)
prompt = user_message
system_message_content = ""
try:
images = await image_generations(
request=request,
form_data=GenerateImageForm(**{"prompt": prompt}),
user=user,
)
await __event_emitter__(
{
"type": "status",
"data": {"description": "Generated an image", "done": True},
}
)
for image in images:
await __event_emitter__(
{
"type": "message",
"data": {"content": f"![Generated Image]({image['url']})\n"},
}
)
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
except Exception as e:
log.exception(e)
await __event_emitter__(
{
"type": "status",
"data": {
"description": f"An error occured while generating an image",
"done": True,
},
}
)
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
if system_message_content:
form_data["messages"] = add_or_update_system_message(
system_message_content, form_data["messages"]
)
return form_data
async def chat_completion_files_handler(
request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
@@ -523,17 +621,28 @@ async def chat_completion_files_handler(
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
sources = get_sources_from_files(
files=files,
queries=queries,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
try:
# Offload get_sources_from_files to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor(
executor,
lambda: get_sources_from_files(
files=files,
queries=queries,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
),
)
except Exception as e:
log.exception(e)
log.debug(f"rag_contexts:sources: {sources}")
return body, {"sources": sources}
@@ -562,6 +671,10 @@ def apply_params_to_form_data(form_data, model):
if "frequency_penalty" in params:
form_data["frequency_penalty"] = params["frequency_penalty"]
if "reasoning_effort" in params:
form_data["reasoning_effort"] = params["reasoning_effort"]
return form_data
@@ -640,6 +753,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
request, form_data, extra_params, user
)
if "image_generation" in features and features["image_generation"]:
form_data = await chat_image_generation_handler(
request, form_data, extra_params, user
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params
@@ -958,6 +1076,16 @@ async def process_chat_response(
},
)
# We might want to disable this by default
detect_reasoning = True
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
current_tag = None
reasoning_start_time = None
reasoning_content = ""
ongoing_content = ""
async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line
@@ -966,12 +1094,12 @@ async def process_chat_response(
if not data.strip():
continue
# "data: " is the prefix for each event
if not data.startswith("data: "):
# "data:" is the prefix for each event
if not data.startswith("data:"):
continue
# Remove the prefix
data = data[len("data: ") :]
data = data[len("data:") :].strip()
try:
data = json.loads(data)
@@ -984,7 +1112,6 @@ async def process_chat_response(
"selectedModelId": data["selected_model_id"],
},
)
else:
value = (
data.get("choices", [])[0]
@@ -995,6 +1122,73 @@ async def process_chat_response(
if value:
content = f"{content}{value}"
if detect_reasoning:
for tag in reasoning_tags:
start_tag = f"<{tag}>\n"
end_tag = f"</{tag}>\n"
if start_tag in content:
# Remove the start tag
content = content.replace(start_tag, "")
ongoing_content = content
reasoning_start_time = time.time()
reasoning_content = ""
current_tag = tag
break
if reasoning_start_time is not None:
# Remove the last value from the content
content = content[: -len(value)]
reasoning_content += value
end_tag = f"</{current_tag}>\n"
if end_tag in reasoning_content:
reasoning_end_time = time.time()
reasoning_duration = int(
reasoning_end_time
- reasoning_start_time
)
reasoning_content = (
reasoning_content.strip(
f"<{current_tag}>\n"
)
.strip(end_tag)
.strip()
)
if reasoning_content:
reasoning_display_content = "\n".join(
(
f"> {line}"
if not line.startswith(">")
else line
)
for line in reasoning_content.splitlines()
)
# Format reasoning with <details> tag
content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
else:
content = ""
reasoning_start_time = None
else:
reasoning_display_content = "\n".join(
(
f"> {line}"
if not line.startswith(">")
else line
)
for line in reasoning_content.splitlines()
)
# Show ongoing thought process
content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
@@ -1015,10 +1209,8 @@ async def process_chat_response(
"data": data,
}
)
except Exception as e:
done = "data: [DONE]" in line
if done:
pass
else:

View File

@@ -63,17 +63,8 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
for _, provider_config in OAUTH_PROVIDERS.items():
provider_config["register"](self.oauth)
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
@@ -200,14 +191,14 @@ class OAuthManager:
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
user_data: UserInfo = token.get("userinfo")
if not user_data:
user_data: UserInfo = await client.userinfo(token=token)
if not user_data:
log.warning(f"OAuth callback failed, user data is missing: {token}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
sub = user_data.get("sub")
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@@ -255,12 +246,20 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
picture_url = user_data.get(
picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
)
if picture_url:
# Download the profile image into a base64 string
try:
access_token = token.get("access_token")
get_kwargs = {}
if access_token:
get_kwargs["headers"] = {
"Authorization": f"Bearer {access_token}",
}
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
async with session.get(picture_url, **get_kwargs) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture

View File

@@ -47,6 +47,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
"top_p": float,
"max_tokens": int,
"frequency_penalty": float,
"reasoning_effort": str,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}

View File

@@ -217,6 +217,24 @@ def tags_generation_template(
return template
def image_prompt_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def emoji_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str: