mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into fix/oidc-500-error-name-field
This commit is contained in:
@@ -1,9 +1,30 @@
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
|
||||
from open_webui.config import DEFAULT_USER_PERMISSIONS
|
||||
import json
|
||||
|
||||
|
||||
def fill_missing_permissions(
|
||||
permissions: Dict[str, Any], default_permissions: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively fills in missing properties in the permissions dictionary
|
||||
using the default permissions as a template.
|
||||
"""
|
||||
for key, value in default_permissions.items():
|
||||
if key not in permissions:
|
||||
permissions[key] = value
|
||||
elif isinstance(value, dict) and isinstance(
|
||||
permissions[key], dict
|
||||
): # Both are nested dictionaries
|
||||
permissions[key] = fill_missing_permissions(permissions[key], value)
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
def get_permissions(
|
||||
user_id: str,
|
||||
default_permissions: Dict[str, Any],
|
||||
@@ -27,39 +48,45 @@ def get_permissions(
|
||||
if key not in permissions:
|
||||
permissions[key] = value
|
||||
else:
|
||||
permissions[key] = permissions[key] or value
|
||||
permissions[key] = (
|
||||
permissions[key] or value
|
||||
) # Use the most permissive value (True > False)
|
||||
return permissions
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
|
||||
# deep copy default permissions to avoid modifying the original dict
|
||||
# Deep copy default permissions to avoid modifying the original dict
|
||||
permissions = json.loads(json.dumps(default_permissions))
|
||||
|
||||
# Combine permissions from all user groups
|
||||
for group in user_groups:
|
||||
group_permissions = group.permissions
|
||||
permissions = combine_permissions(permissions, group_permissions)
|
||||
|
||||
# Ensure all fields from default_permissions are present and filled in
|
||||
permissions = fill_missing_permissions(permissions, default_permissions)
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
def has_permission(
|
||||
user_id: str,
|
||||
permission_key: str,
|
||||
default_permissions: Dict[str, bool] = {},
|
||||
default_permissions: Dict[str, Any] = {},
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has a specific permission by checking the group permissions
|
||||
and falls back to default permissions if not found in any group.
|
||||
and fall back to default permissions if not found in any group.
|
||||
|
||||
Permission keys can be hierarchical and separated by dots ('.').
|
||||
"""
|
||||
|
||||
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
|
||||
def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool:
|
||||
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
|
||||
for key in keys:
|
||||
if key not in permissions:
|
||||
return False # If any part of the hierarchy is missing, deny access
|
||||
permissions = permissions[key] # Go one level deeper
|
||||
permissions = permissions[key] # Traverse one level deeper
|
||||
|
||||
return bool(permissions) # Return the boolean at the final level
|
||||
|
||||
@@ -73,7 +100,10 @@ def has_permission(
|
||||
if get_permission(group_permissions, permission_hierarchy):
|
||||
return True
|
||||
|
||||
# Check default permissions afterwards if the group permissions don't allow it
|
||||
# Check default permissions afterward if the group permissions don't allow it
|
||||
default_permissions = fill_missing_permissions(
|
||||
default_permissions, DEFAULT_USER_PERMISSIONS
|
||||
)
|
||||
return get_permission(default_permissions, permission_hierarchy)
|
||||
|
||||
|
||||
|
||||
@@ -28,9 +28,13 @@ from open_webui.socket.main import (
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_image_prompt,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.routers.images import image_generations, GenerateImageForm
|
||||
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
|
||||
@@ -486,6 +490,100 @@ async def chat_web_search_handler(
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_image_generation_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
__event_emitter__ = extra_params["__event_emitter__"]
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {"description": "Generating an image", "done": False},
|
||||
}
|
||||
)
|
||||
|
||||
messages = form_data["messages"]
|
||||
user_message = get_last_user_message(messages)
|
||||
|
||||
prompt = user_message
|
||||
negative_prompt = ""
|
||||
|
||||
if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
|
||||
try:
|
||||
res = await generate_image_prompt(
|
||||
request,
|
||||
{
|
||||
"model": form_data["model"],
|
||||
"messages": messages,
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
response = res["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
bracket_start = response.find("{")
|
||||
bracket_end = response.rfind("}") + 1
|
||||
|
||||
if bracket_start == -1 or bracket_end == -1:
|
||||
raise Exception("No JSON object found in the response")
|
||||
|
||||
response = response[bracket_start:bracket_end]
|
||||
response = json.loads(response)
|
||||
prompt = response.get("prompt", [])
|
||||
except Exception as e:
|
||||
prompt = user_message
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
prompt = user_message
|
||||
|
||||
system_message_content = ""
|
||||
|
||||
try:
|
||||
images = await image_generations(
|
||||
request=request,
|
||||
form_data=GenerateImageForm(**{"prompt": prompt}),
|
||||
user=user,
|
||||
)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {"description": "Generated an image", "done": True},
|
||||
}
|
||||
)
|
||||
|
||||
for image in images:
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "message",
|
||||
"data": {"content": f"\n"},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"description": f"An error occured while generating an image",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
|
||||
|
||||
if system_message_content:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system_message_content, form_data["messages"]
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_completion_files_handler(
|
||||
request: Request, body: dict, user: UserModel
|
||||
) -> tuple[dict, dict[str, list]]:
|
||||
@@ -523,17 +621,28 @@ async def chat_completion_files_handler(
|
||||
if len(queries) == 0:
|
||||
queries = [get_last_user_message(body["messages"])]
|
||||
|
||||
sources = get_sources_from_files(
|
||||
files=files,
|
||||
queries=queries,
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
k=request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
try:
|
||||
# Offload get_sources_from_files to a separate thread
|
||||
loop = asyncio.get_running_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
sources = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: get_sources_from_files(
|
||||
files=files,
|
||||
queries=queries,
|
||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||
k=request.app.state.config.TOP_K,
|
||||
reranking_function=request.app.state.rf,
|
||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
log.debug(f"rag_contexts:sources: {sources}")
|
||||
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
@@ -562,6 +671,10 @@ def apply_params_to_form_data(form_data, model):
|
||||
|
||||
if "frequency_penalty" in params:
|
||||
form_data["frequency_penalty"] = params["frequency_penalty"]
|
||||
|
||||
if "reasoning_effort" in params:
|
||||
form_data["reasoning_effort"] = params["reasoning_effort"]
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
@@ -640,6 +753,11 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
if "image_generation" in features and features["image_generation"]:
|
||||
form_data = await chat_image_generation_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
@@ -958,6 +1076,16 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
# We might want to disable this by default
|
||||
detect_reasoning = True
|
||||
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
|
||||
current_tag = None
|
||||
|
||||
reasoning_start_time = None
|
||||
|
||||
reasoning_content = ""
|
||||
ongoing_content = ""
|
||||
|
||||
async for line in response.body_iterator:
|
||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||
data = line
|
||||
@@ -966,12 +1094,12 @@ async def process_chat_response(
|
||||
if not data.strip():
|
||||
continue
|
||||
|
||||
# "data: " is the prefix for each event
|
||||
if not data.startswith("data: "):
|
||||
# "data:" is the prefix for each event
|
||||
if not data.startswith("data:"):
|
||||
continue
|
||||
|
||||
# Remove the prefix
|
||||
data = data[len("data: ") :]
|
||||
data = data[len("data:") :].strip()
|
||||
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -984,7 +1112,6 @@ async def process_chat_response(
|
||||
"selectedModelId": data["selected_model_id"],
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
value = (
|
||||
data.get("choices", [])[0]
|
||||
@@ -995,6 +1122,73 @@ async def process_chat_response(
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
if detect_reasoning:
|
||||
for tag in reasoning_tags:
|
||||
start_tag = f"<{tag}>\n"
|
||||
end_tag = f"</{tag}>\n"
|
||||
|
||||
if start_tag in content:
|
||||
# Remove the start tag
|
||||
content = content.replace(start_tag, "")
|
||||
ongoing_content = content
|
||||
|
||||
reasoning_start_time = time.time()
|
||||
reasoning_content = ""
|
||||
|
||||
current_tag = tag
|
||||
break
|
||||
|
||||
if reasoning_start_time is not None:
|
||||
# Remove the last value from the content
|
||||
content = content[: -len(value)]
|
||||
|
||||
reasoning_content += value
|
||||
|
||||
end_tag = f"</{current_tag}>\n"
|
||||
if end_tag in reasoning_content:
|
||||
reasoning_end_time = time.time()
|
||||
reasoning_duration = int(
|
||||
reasoning_end_time
|
||||
- reasoning_start_time
|
||||
)
|
||||
reasoning_content = (
|
||||
reasoning_content.strip(
|
||||
f"<{current_tag}>\n"
|
||||
)
|
||||
.strip(end_tag)
|
||||
.strip()
|
||||
)
|
||||
|
||||
if reasoning_content:
|
||||
reasoning_display_content = "\n".join(
|
||||
(
|
||||
f"> {line}"
|
||||
if not line.startswith(">")
|
||||
else line
|
||||
)
|
||||
for line in reasoning_content.splitlines()
|
||||
)
|
||||
|
||||
# Format reasoning with <details> tag
|
||||
content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
else:
|
||||
content = ""
|
||||
|
||||
reasoning_start_time = None
|
||||
else:
|
||||
|
||||
reasoning_display_content = "\n".join(
|
||||
(
|
||||
f"> {line}"
|
||||
if not line.startswith(">")
|
||||
else line
|
||||
)
|
||||
for line in reasoning_content.splitlines()
|
||||
)
|
||||
|
||||
# Show ongoing thought process
|
||||
content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
|
||||
if ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
@@ -1015,10 +1209,8 @@ async def process_chat_response(
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
done = "data: [DONE]" in line
|
||||
|
||||
if done:
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -63,17 +63,8 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
class OAuthManager:
|
||||
def __init__(self):
|
||||
self.oauth = OAuth()
|
||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||
self.oauth.register(
|
||||
name=provider_name,
|
||||
client_id=provider_config["client_id"],
|
||||
client_secret=provider_config["client_secret"],
|
||||
server_metadata_url=provider_config["server_metadata_url"],
|
||||
client_kwargs={
|
||||
"scope": provider_config["scope"],
|
||||
},
|
||||
redirect_uri=provider_config["redirect_uri"],
|
||||
)
|
||||
for _, provider_config in OAUTH_PROVIDERS.items():
|
||||
provider_config["register"](self.oauth)
|
||||
|
||||
def get_client(self, provider_name):
|
||||
return self.oauth.create_client(provider_name)
|
||||
@@ -200,14 +191,14 @@ class OAuthManager:
|
||||
except Exception as e:
|
||||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token["userinfo"]
|
||||
user_data: UserInfo = token.get("userinfo")
|
||||
if not user_data:
|
||||
user_data: UserInfo = await client.userinfo(token=token)
|
||||
if not user_data:
|
||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
sub = user_data.get("sub")
|
||||
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
||||
if not sub:
|
||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
@@ -255,12 +246,20 @@ class OAuthManager:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
||||
picture_url = user_data.get(picture_claim, "")
|
||||
picture_url = user_data.get(
|
||||
picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
|
||||
)
|
||||
if picture_url:
|
||||
# Download the profile image into a base64 string
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
get_kwargs = {}
|
||||
if access_token:
|
||||
get_kwargs["headers"] = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url) as resp:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
|
||||
@@ -47,6 +47,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
"top_p": float,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": float,
|
||||
"reasoning_effort": str,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
@@ -217,6 +217,24 @@ def tags_generation_template(
|
||||
return template
|
||||
|
||||
|
||||
def image_prompt_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def emoji_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user