Merge branch 'dev' into bug/user-signup/fix-oauth-username-claim-has-no-effect

This commit is contained in:
pseudorm
2025-02-10 22:00:20 +08:00
committed by GitHub
102 changed files with 2853 additions and 1223 deletions

View File

@@ -660,6 +660,7 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
@@ -1325,6 +1326,48 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""
####################################
# Code Interpreter
####################################
ENABLE_CODE_INTERPRETER = PersistentConfig(
"ENABLE_CODE_INTERPRETER",
"code_interpreter.enable",
os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
)
CODE_INTERPRETER_ENGINE = PersistentConfig(
"CODE_INTERPRETER_ENGINE",
"code_interpreter.engine",
os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
)
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_URL",
"code_interpreter.jupyter.url",
os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH",
"code_interpreter.jupyter.auth",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
"code_interpreter.jupyter.auth_token",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
"code_interpreter.jupyter.auth_password",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
)
DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
@@ -1645,7 +1688,7 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.rag.web.search.domain.filter_list",
"rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
@@ -2012,6 +2055,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
# Add Deepgram configuration
DEEPGRAM_API_KEY = PersistentConfig(
"DEEPGRAM_API_KEY",
"audio.stt.deepgram.api_key",
os.getenv("DEEPGRAM_API_KEY", ""),
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",

View File

@@ -92,6 +92,7 @@ log_sources = [
"RAG",
"WEBHOOK",
"SOCKET",
"OAUTH",
]
SRC_LOG_LEVELS = {}

View File

@@ -97,6 +97,13 @@ from open_webui.config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
# Code Interpreter
ENABLE_CODE_INTERPRETER,
CODE_INTERPRETER_ENGINE,
CODE_INTERPRETER_JUPYTER_URL,
CODE_INTERPRETER_JUPYTER_AUTH,
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@@ -130,6 +137,7 @@ from open_webui.config import (
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
WHISPER_MODEL,
DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
# Retrieval
@@ -569,6 +577,23 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
########################################
#
# CODE INTERPRETER
#
########################################
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
########################################
#
@@ -611,6 +636,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@@ -753,6 +779,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -1013,12 +1040,14 @@ async def get_app_config(request: Request):
{
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
}
if user is not None
else {}

View File

@@ -470,7 +470,7 @@ class ChatTable:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checkng if a chat with the share_id exists
# we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:

View File

@@ -113,6 +113,34 @@ class OpenSearchClient:
return self._result_to_search_result(result)
def query(
    self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
    """Fetch documents from a collection whose fields exactly match ``filter``.

    Every ``field: value`` pair in ``filter`` becomes one ``term`` clause of a
    ``bool`` filter, so all pairs have to match. Only the ``text`` and
    ``metadata`` source fields are requested.

    Args:
        collection_name: Collection to search; the index queried is
            ``f"{self.index_prefix}_{collection_name}"``.
        filter: Mapping of field name to the exact value required.
        limit: Maximum number of hits; a falsy value (``None`` or ``0``)
            falls back to 10.

    Returns:
        The search response converted by ``_result_to_get_result``, or
        ``None`` when the collection does not exist or the search raises.
    """
    if not self.has_collection(collection_name):
        return None

    term_clauses = [{"term": {name: wanted}} for name, wanted in filter.items()]
    search_body = {
        "query": {"bool": {"filter": term_clauses}},
        "_source": ["text", "metadata"],
    }
    max_hits = limit or 10

    try:
        response = self.client.search(
            index=f"{self.index_prefix}_{collection_name}",
            body=search_body,
            size=max_hits,
        )
        return self._result_to_get_result(response)
    except Exception:
        # Best-effort lookup: any search failure is reported as "no result".
        return None
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)

View File

@@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": api_key,
"X-Retain-Images": "none",
}
payload = {"q": query, "count": count if count <= 10 else 10}
url = str(URL(jina_search_endpoint))
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
for result in data["data"]:
results.append(
SearchResult(
link=result["url"],

View File

@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
import mimetypes
from fastapi import (
Depends,
@@ -138,6 +139,7 @@ class STTConfigForm(BaseModel):
ENGINE: str
MODEL: str
WHISPER_MODEL: str
DEEPGRAM_API_KEY: str
class AudioConfigUpdateForm(BaseModel):
@@ -165,6 +167,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -190,6 +193,7 @@ async def update_audio_config(
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -214,6 +218,7 @@ async def update_audio_config(
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@@ -521,6 +526,69 @@ def transcribe(request: Request, file_path):
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
elif request.app.state.config.STT_ENGINE == "deepgram":
try:
# Determine the MIME type of the file
mime, _ = mimetypes.guess_type(file_path)
if not mime:
mime = "audio/wav" # fallback to wav if undetectable
# Read the audio file
with open(file_path, "rb") as f:
file_data = f.read()
# Build headers and parameters
headers = {
"Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
"Content-Type": mime,
}
# Add model if specified
params = {}
if request.app.state.config.STT_MODEL:
params["model"] = request.app.state.config.STT_MODEL
# Make request to Deepgram API
r = requests.post(
"https://api.deepgram.com/v1/listen",
headers=headers,
params=params,
data=file_data,
)
r.raise_for_status()
response_data = r.json()
# Extract transcript from Deepgram response
try:
transcript = response_data["results"]["channels"][0]["alternatives"][
0
].get("transcript", "")
except (KeyError, IndexError) as e:
log.error(f"Malformed response from Deepgram: {str(e)}")
raise Exception(
"Failed to parse Deepgram response - unexpected response format"
)
data = {"text": transcript.strip()}
# Save transcript
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
return data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:

View File

@@ -36,6 +36,61 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
############################
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
    """Request/response schema for the ``/code_interpreter`` config endpoints.

    The field names mirror the ``request.app.state.config`` attributes that
    those endpoints read and write.
    """

    # Whether the code interpreter feature is switched on.
    ENABLE_CODE_INTERPRETER: bool
    # Name of the execution engine.
    CODE_INTERPRETER_ENGINE: str
    # Jupyter connection settings; each accepts a string or null.
    CODE_INTERPRETER_JUPYTER_URL: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
    """Return the current code-interpreter settings.

    Access is gated by the ``get_admin_user`` dependency. The values are read
    straight from ``request.app.state.config``.
    """
    config = request.app.state.config
    return {
        "ENABLE_CODE_INTERPRETER": config.ENABLE_CODE_INTERPRETER,
        "CODE_INTERPRETER_ENGINE": config.CODE_INTERPRETER_ENGINE,
        "CODE_INTERPRETER_JUPYTER_URL": config.CODE_INTERPRETER_JUPYTER_URL,
        "CODE_INTERPRETER_JUPYTER_AUTH": config.CODE_INTERPRETER_JUPYTER_AUTH,
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
    }
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
    request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
    """Store the submitted code-interpreter settings and echo them back.

    Access is gated by the ``get_admin_user`` dependency. Each form field is
    written to the same-named attribute of ``request.app.state.config``; the
    response is then read back from that config object.
    """
    config = request.app.state.config
    field_names = (
        "ENABLE_CODE_INTERPRETER",
        "CODE_INTERPRETER_ENGINE",
        "CODE_INTERPRETER_JUPYTER_URL",
        "CODE_INTERPRETER_JUPYTER_AUTH",
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
    )

    # Write every field first, then read them all back, in declaration order.
    for field_name in field_names:
        setattr(config, field_name, getattr(form_data, field_name))

    return {field_name: getattr(config, field_name) for field_name in field_names}
############################
# SetDefaultModels
############################

View File

@@ -11,10 +11,8 @@ import re
import time
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from fastapi import (
@@ -990,6 +988,8 @@ async def generate_chat_completion(
)
payload = {**form_data.model_dump(exclude_none=True)}
if "metadata" in payload:
del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
@@ -1408,9 +1408,10 @@ async def download_model(
return None
# TODO: Progress bar does not reflect size & duration of upload.
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
def upload_model(
async def upload_model(
request: Request,
file: UploadFile = File(...),
url_idx: Optional[int] = None,
@@ -1419,59 +1420,85 @@ def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = os.path.join(UPLOAD_DIR, file.filename)
os.makedirs(UPLOAD_DIR, exist_ok=True)
file_path = f"{UPLOAD_DIR}/{file.filename}"
# --- P1: save file locally ---
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
with open(file_path, "wb") as out_f:
while True:
chunk = file.file.read(chunk_size)
# log.info(f"Chunk: {str(chunk)}") # DEBUG
if not chunk:
break
out_f.write(chunk)
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
async def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
# --- P2: SSE progress + calculate sha256 hash ---
file_hash = calculate_sha256(file_path, chunk_size)
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
bytes_read = 0
while chunk := f.read(chunk_size):
bytes_read += len(chunk)
progress = round(bytes_read / total_size * 100, 2)
data_msg = {
"progress": progress,
"total": total_size,
"completed": total,
"completed": bytes_read,
}
yield f"data: {json.dumps(res)}\n\n"
yield f"data: {json.dumps(data_msg)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
# --- P3: Upload to ollama /api/blobs ---
with open(file_path, "rb") as f:
url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
response = requests.post(url, data=f)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
log.info(f"Uploaded to /api/blobs") # DEBUG
# Remove local file
os.remove(file_path)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
# Create model in ollama
model_name, ext = os.path.splitext(file.filename)
log.info(f"Created Model: {model_name}") # DEBUG
create_payload = {
"model": model_name,
# Reference the file by its original name => the uploaded blob's digest
"files": {file.filename: f"sha256:{file_hash}"},
}
log.info(f"Model Payload: {create_payload}") # DEBUG
# Call ollama /api/create
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
create_resp = requests.post(
url=f"{ollama_url}/api/create",
headers={"Content-Type": "application/json"},
data=json.dumps(create_payload),
)
if create_resp.ok:
log.info(f"API SUCCESS!") # DEBUG
done_msg = {
"done": True,
"blob": f"sha256:{file_hash}",
"name": file.filename,
"model_created": model_name,
}
yield f"data: {json.dumps(done_msg)}\n\n"
else:
raise Exception(
f"Failed to create model in Ollama. {create_resp.text}"
)
else:
raise Exception("Ollama: Could not create blob, Please try again.")
except Exception as e:
res = {"error": str(e)}

View File

@@ -75,9 +75,9 @@ async def cleanup_response(
await session.close()
def openai_o1_handler(payload):
def openai_o1_o3_handler(payload):
"""
Handle O1 specific parameters
Handle o1, o3 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
@@ -621,10 +621,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-")
if is_o1:
payload = openai_o1_handler(payload)
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
if is_o1_o3:
payload = openai_o1_o3_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:

View File

@@ -392,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@@ -441,6 +442,7 @@ class WebSearchConfig(BaseModel):
exa_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
domain_filter_list: Optional[List[str]] = []
class WebConfig(BaseModel):
@@ -553,6 +555,9 @@ async def update_rag_config(
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
form_data.web.search.domain_filter_list
)
return {
"status": True,
@@ -599,6 +604,7 @@ async def update_rag_config(
"exa_api_key": request.app.state.config.EXA_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}

View File

@@ -10,6 +10,7 @@ from open_webui.config import (
S3_ACCESS_KEY_ID,
S3_BUCKET_NAME,
S3_ENDPOINT_URL,
S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
GCS_BUCKET_NAME,
@@ -93,15 +94,17 @@ class S3StorageProvider(StorageProvider):
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
try:
self.s3_client.upload_file(file_path, self.bucket_name, filename)
s3_key = os.path.join(self.key_prefix, filename)
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + filename,
"s3://" + self.bucket_name + "/" + s3_key,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
@@ -109,18 +112,18 @@ class S3StorageProvider(StorageProvider):
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
s3_key = self._extract_s3_key(file_path)
local_file_path = self._get_local_file_path(s3_key)
self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from S3 storage."""
filename = file_path.split("/")[-1]
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
s3_key = self._extract_s3_key(file_path)
self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
@@ -133,6 +136,10 @@ class S3StorageProvider(StorageProvider):
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
# Skip objects that were not uploaded from open-webui in the first place
if not content["Key"].startswith(self.key_prefix):
continue
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
@@ -142,6 +149,13 @@ class S3StorageProvider(StorageProvider):
# Always delete from local storage
LocalStorageProvider.delete_all_files()
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
def _extract_s3_key(self, full_file_path: str) -> str:
    """Return the object key embedded in an ``s3://<bucket>/<key>`` path.

    Args:
        full_file_path: Path of the form ``s3://<bucket>/<key>``.

    Returns:
        Everything after the bucket name, or ``""`` when the path names only
        a bucket.

    Raises:
        IndexError: If ``full_file_path`` contains no ``//`` separator.
    """
    # Split the scheme off only once so that a key which itself contains "//"
    # is kept whole instead of being cut at the second "//".
    bucket_and_key = full_file_path.split("//", 1)[1]
    # The bucket name ends at the first "/"; the remainder is the key.
    return bucket_and_key.partition("/")[2]
def _get_local_file_path(self, s3_key: str) -> str:
    """Return the path under ``UPLOAD_DIR`` for the file named by ``s3_key``.

    Only the last ``/``-separated segment of the key is used, so any key
    prefix is dropped from the local path.
    """
    file_name = s3_key.rsplit("/", 1)[-1]
    return f"{UPLOAD_DIR}/{file_name}"
class GCSStorageProvider(StorageProvider):
def __init__(self):

View File

@@ -44,6 +44,10 @@ from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
@@ -177,116 +181,38 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
except Exception as e:
return Exception(f"Error: {e}")
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
metadata = {
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
}
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__request__": request,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data
try:
result, _ = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="outlet",
form_data=data,
extra_params=extra_params,
)
return result
except Exception as e:
return Exception(f"Error: {e}")
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):

View File

@@ -0,0 +1,153 @@
import asyncio
import json
import uuid
import websockets
import requests
from urllib.parse import urljoin
async def execute_code_jupyter(
    jupyter_url, code, token=None, password=None, timeout=10
):
    """
    Executes Python code in a Jupyter kernel.
    Supports authentication with a token or password.

    A new kernel is created for the call and deleted again afterwards,
    whether or not execution succeeded. Failures are reported through the
    ``stderr`` field of the returned dictionary instead of being raised.

    :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional, only used when no token is given)
    :param timeout: Seconds to wait for each WebSocket message (default: 10s)
    :return: Dictionary with stdout, stderr, and result
    """
    session = requests.Session()  # Maintain cookies
    headers = {}  # Headers for requests

    # Authenticate using password
    if password and not token:
        try:
            login_url = urljoin(jupyter_url, "/login")
            response = session.get(login_url)
            response.raise_for_status()

            # Retrieve `_xsrf` token
            xsrf_token = session.cookies.get("_xsrf")
            if not xsrf_token:
                raise ValueError("Failed to fetch _xsrf token")

            # Send login request
            login_data = {"_xsrf": xsrf_token, "password": password}
            login_response = session.post(
                login_url, data=login_data, cookies=session.cookies
            )
            login_response.raise_for_status()

            # Update headers with `_xsrf`
            headers["X-XSRFToken"] = xsrf_token
        except Exception as e:
            return {
                "stdout": "",
                "stderr": f"Authentication Error: {str(e)}",
                "result": "",
            }

    # Construct API URLs with authentication token if provided
    params = f"?token={token}" if token else ""
    kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")

    # Bound before the `try` so the `finally` clause can test it even when
    # kernel creation itself fails.
    kernel_id = None

    try:
        # Include cookies if authenticating with password
        response = session.post(kernel_url, headers=headers, cookies=session.cookies)
        response.raise_for_status()
        kernel_id = response.json()["id"]

        # Construct WebSocket URL. Only the leading scheme is rewritten
        # (http -> ws, https -> wss); later "http" substrings stay untouched.
        websocket_url = urljoin(
            jupyter_url.replace("http", "ws", 1),
            f"/api/kernels/{kernel_id}/channels{params}",
        )

        # **IMPORTANT:** Include authentication cookies for WebSockets
        ws_headers = {}
        if password and not token:
            ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
            cookies = {name: value for name, value in session.cookies.items()}
            ws_headers["Cookie"] = "; ".join(
                [f"{name}={value}" for name, value in cookies.items()]
            )

        # Connect to the WebSocket
        async with websockets.connect(
            websocket_url, additional_headers=ws_headers
        ) as ws:
            msg_id = str(uuid.uuid4())

            # Send execution request
            execute_request = {
                "header": {
                    "msg_id": msg_id,
                    "msg_type": "execute_request",
                    "username": "user",
                    "session": str(uuid.uuid4()),
                    "date": "",
                    "version": "5.3",
                },
                "parent_header": {},
                "metadata": {},
                "content": {
                    "code": code,
                    "silent": False,
                    "store_history": True,
                    "user_expressions": {},
                    "allow_stdin": False,
                    "stop_on_error": True,
                },
                "channel": "shell",
            }
            await ws.send(json.dumps(execute_request))

            # Collect execution results
            stdout, stderr, result = "", "", None

            while True:
                try:
                    message = await asyncio.wait_for(ws.recv(), timeout)
                    message_data = json.loads(message)
                    # Only messages answering our own request are relevant.
                    if message_data.get("parent_header", {}).get("msg_id") == msg_id:
                        msg_type = message_data.get("msg_type")

                        if msg_type == "stream":
                            if message_data["content"]["name"] == "stdout":
                                stdout += message_data["content"]["text"]
                            elif message_data["content"]["name"] == "stderr":
                                stderr += message_data["content"]["text"]

                        elif msg_type in ("execute_result", "display_data"):
                            result = message_data["content"]["data"].get(
                                "text/plain", ""
                            )

                        elif msg_type == "error":
                            stderr += "\n".join(message_data["content"]["traceback"])

                        elif (
                            msg_type == "status"
                            and message_data["content"]["execution_state"] == "idle"
                        ):
                            # The kernel went idle for our request: execution finished.
                            break

                except asyncio.TimeoutError:
                    stderr += "\nExecution timed out."
                    break

    except Exception as e:
        return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}

    finally:
        # Shutdown the kernel
        if kernel_id:
            # Build the URL from its parts: `kernel_url` already ends with the
            # `?token=...` query string, so appending the id to it would put
            # the id inside the query instead of the path.
            requests.delete(
                urljoin(jupyter_url, f"/api/kernels/{kernel_id}{params}"),
                headers=headers,
                cookies=session.cookies,
            )

    return {
        "stdout": stdout.strip(),
        "stderr": stderr.strip(),
        "result": result.strip() if result else "",
    }
# Example Usage
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))

View File

@@ -0,0 +1,99 @@
import inspect
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.models.functions import Functions
def get_sorted_filter_ids(model):
    """Return the ids of the enabled filter functions for ``model``, by priority.

    Candidates are the global filter functions plus any ids listed under
    ``model["info"]["meta"]["filterIds"]``. Ids that are not among the active
    ``"filter"`` functions are dropped. The remainder is sorted ascending by
    the ``priority`` entry of each function's valves (0 when absent).

    :param model: Mapping describing the model; ``info``/``meta`` are optional.
    :return: List of filter ids without duplicates. Ids with equal priority
        keep their first-seen order (global filters first, then the model's).
    """

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            # TODO: Fix FunctionModel to include valves
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    if "info" in model and "meta" in model["info"]:
        filter_ids.extend(model["info"]["meta"].get("filterIds", []))

    # A set gives constant-time membership tests in the filter below.
    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }

    # dict.fromkeys de-duplicates while keeping insertion order, so the stable
    # sort below yields the same order on every run for equal priorities.
    filter_ids = [fid for fid in dict.fromkeys(filter_ids) if fid in enabled_filter_ids]
    filter_ids.sort(key=get_priority)
    return filter_ids
async def process_filter_functions(
    request, filter_ids, filter_type, form_data, extra_params
):
    """Run the ``filter_type`` hook of each listed filter over ``form_data``.

    For every ID in ``filter_ids`` (in the given order) the filter module is
    taken from ``request.app.state.FUNCTIONS`` (loading and caching it there
    on a miss), its valves are refreshed, and the attribute named
    ``filter_type`` is called. Whatever the hook returns becomes the
    ``form_data`` handed to the next filter.

    Args:
        request: Object exposing ``app.state.FUNCTIONS`` as the module cache.
        filter_ids: Iterable of filter function IDs to apply.
        filter_type: Name of the hook attribute to call on each module.
        form_data: Payload passed to the hooks as ``body``.
        extra_params: Extra keyword candidates; only those named in a hook's
            signature are passed to it.

    Returns:
        tuple: ``(form_data, {})`` with the payload after all hooks ran.

    Raises:
        Exception: Any error raised while preparing or running a hook is
            printed and re-raised.
    """
    drop_files = None

    for filter_id in filter_ids:
        # IDs without a stored function record are ignored.
        if not Functions.get_function_by_id(filter_id):
            continue

        # Reuse the cached module, loading and caching it on first use.
        loaded_modules = request.app.state.FUNCTIONS
        if filter_id not in loaded_modules:
            fresh_module, _, _ = load_function_module_by_id(filter_id)
            loaded_modules[filter_id] = fresh_module
        function_module = loaded_modules[filter_id]

        # Check if the function has a file_handler variable
        if filter_type == "inlet" and hasattr(function_module, "file_handler"):
            drop_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            stored_valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(**(stored_valves or {}))

        # Modules that do not define the requested hook are skipped.
        handler = getattr(function_module, filter_type, None)
        if not handler:
            continue

        try:
            # Pass only the extras that the hook's signature names.
            accepted = inspect.signature(handler).parameters
            candidates = {**extra_params, "__id__": filter_id}
            params = {"body": form_data}
            params.update(
                {key: value for key, value in candidates.items() if key in accepted}
            )

            # Attach per-user valves; a failure here is printed, not raised.
            if "__user__" in accepted and hasattr(function_module, "UserValves"):
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception as e:
                    print(e)

            # Execute handler, awaiting it when it is a coroutine function.
            outcome = handler(**params)
            if inspect.iscoroutinefunction(handler):
                outcome = await outcome
            form_data = outcome
        except Exception as e:
            print(f"Error in {filter_type} handler {filter_id}: {e}")
            raise e

    # Handle file cleanup for inlet
    if drop_files and "files" in form_data.get("metadata", {}):
        del form_data["metadata"]["files"]

    return form_data, {}

View File

@@ -161,7 +161,7 @@ async def comfyui_generate_image(
seed = (
payload.seed
if payload.seed
else random.randint(0, 18446744073709551614)
else random.randint(0, 1125899906842624)
)
for node_id in node.node_ids:
workflow[node_id]["inputs"][node.key] = seed

View File

@@ -68,7 +68,11 @@ from open_webui.utils.misc import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.tasks import create_task
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
    """Apply the ``inlet`` hook of every applicable filter function to ``body``.

    Applicable filters are the global filters plus those listed under
    ``model["info"]["meta"]["filterIds"]``, limited to active filters and
    ordered by their ``priority`` valve. Each hook receives ``body`` and
    returns the value handed to the next hook.

    Args:
        request: Object exposing ``app.state.FUNCTIONS`` as the module cache.
        body: Payload passed to each ``inlet`` hook.
        model: Model mapping; also offered to hooks as ``__model__``.
        extra_params: Extra keyword candidates; only those named in a hook's
            signature are passed to it.

    Returns:
        tuple: ``(body, {})`` with the payload after all hooks ran.

    Raises:
        Exception: Any error raised while preparing or running a hook is
            printed and re-raised.
    """

    def get_filter_function_ids(model):
        def get_priority(function_id):
            record = Functions.get_function_by_id(function_id)
            if record is None or not hasattr(record, "valves"):
                return 0
            # TODO: Fix FunctionModel
            valves = record.valves if record.valves else {}
            return valves.get("priority", 0)

        candidate_ids = [fn.id for fn in Functions.get_global_filter_functions()]
        if "info" in model and "meta" in model["info"]:
            candidate_ids.extend(model["info"]["meta"].get("filterIds", []))
            candidate_ids = list(set(candidate_ids))

        active_ids = [
            fn.id for fn in Functions.get_functions_by_type("filter", active_only=True)
        ]
        # sorted() is stable, matching an in-place list.sort().
        return sorted(
            (fid for fid in candidate_ids if fid in active_ids), key=get_priority
        )

    skip_files = None

    for filter_id in get_filter_function_ids(model):
        # IDs without a stored function record are ignored.
        if not Functions.get_function_by_id(filter_id):
            continue

        # Reuse the cached module, loading and caching it on first use.
        registry = request.app.state.FUNCTIONS
        if filter_id not in registry:
            loaded_module, _, _ = load_function_module_by_id(filter_id)
            registry[filter_id] = loaded_module
        function_module = registry[filter_id]

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Apply valves to the function
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            stored_valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(**(stored_valves or {}))

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet
            accepted = inspect.signature(inlet).parameters

            # Pass only the extras that the hook's signature names.
            available = {**extra_params, "__model__": model, "__id__": filter_id}
            params = {"body": body}
            params.update(
                {key: value for key, value in available.items() if key in accepted}
            )

            # Attach per-user valves; a failure here is printed, not raised.
            if "__user__" in params and hasattr(function_module, "UserValves"):
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception as e:
                    print(e)

            # Call the hook, awaiting it when it is a coroutine function.
            outcome = inlet(**params)
            if inspect.iscoroutinefunction(inlet):
                outcome = await outcome
            body = outcome
        except Exception as e:
            print(f"Error: {e}")
            raise e

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
@@ -572,13 +483,13 @@ async def chat_image_generation_handler(
{
"type": "status",
"data": {
"description": f"An error occured while generating an image",
"description": f"An error occurred while generating an image",
"done": True,
},
}
)
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
if system_message_content:
form_data["messages"] = add_or_update_system_message(
@@ -706,6 +617,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
}
# Initialize events to store additional event to be sent to the client
@@ -782,8 +694,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
)
try:
form_data, flags = await chat_completion_filter_functions_handler(
request, form_data, model, extra_params
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
@@ -1122,6 +1038,20 @@ async def process_chat_response(
},
)
def split_content_and_whitespace(content):
    """Split ``content`` into its text and its trailing whitespace.

    Returns:
        tuple: ``(stripped, trailing)`` where ``stripped`` is ``content``
        with trailing whitespace removed and ``trailing`` is the removed
        suffix (``""`` when there was none), so that
        ``stripped + trailing == content``.
    """
    body = content.rstrip()
    # Slicing from the stripped length yields "" when nothing was stripped.
    trailing = content[len(body) :]
    return body, trailing
def is_opening_code_block(content):
    """Return True when the last triple-backtick fence in ``content`` opens a block.

    An odd number of fences means the final one has no closing partner;
    zero fences or an even count yields False.
    """
    fence_count = content.count("```")
    return fence_count % 2 == 1
# Handle as a background task
async def post_response_handler(response, events):
def serialize_content_blocks(content_blocks, raw=False):
@@ -1188,6 +1118,19 @@ async def process_chat_response(
output = block.get("output", None)
lang = attributes.get("lang", "")
content_stripped, original_whitespace = (
split_content_and_whitespace(content)
)
if is_opening_code_block(content_stripped):
# Remove trailing backticks that would open a new block
content = (
content_stripped.rstrip("`").rstrip()
+ original_whitespace
)
else:
# Keep content as is - either closing backticks or no backticks
content = content_stripped + original_whitespace
if output:
output = html.escape(json.dumps(output))
@@ -1242,10 +1185,10 @@ async def process_chat_response(
match.end() :
] # Content after opening tag
# Remove the start tag from the currently handling text block
# Remove the start tag and after from the currently handling text block
content_blocks[-1]["content"] = content_blocks[-1][
"content"
].replace(match.group(0), "")
].replace(match.group(0) + after_tag, "")
if before_tag:
content_blocks[-1]["content"] = before_tag
@@ -1708,15 +1651,45 @@ async def process_chat_response(
output = ""
try:
if content_blocks[-1]["attributes"].get("type") == "code":
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": content_blocks[-1]["content"],
},
code = content_blocks[-1]["content"]
if (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "pyodide"
):
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": code,
},
}
)
elif (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "jupyter"
):
output = await execute_code_jupyter(
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
code,
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "token"
else None
),
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "password"
else None
),
)
else:
output = {
"stdout": "Code interpreter engine not configured."
}
)
if isinstance(output, dict):
stdout = output.get("stdout", "")

View File

@@ -244,11 +244,12 @@ def get_gravatar_url(email):
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
def calculate_sha256(file):
def calculate_sha256(file_path, chunk_size):
# Compute SHA-256 hash of a file efficiently in chunks
sha256 = hashlib.sha256()
# Read the file in chunks to efficiently handle large files
for chunk in iter(lambda: file.read(8192), b""):
sha256.update(chunk)
with open(file_path, "rb") as f:
while chunk := f.read(chunk_size):
sha256.update(chunk)
return sha256.hexdigest()

View File

@@ -1,6 +1,7 @@
import base64
import logging
import mimetypes
import sys
import uuid
import aiohttp
@@ -40,7 +41,11 @@ from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
@@ -72,12 +77,15 @@ class OAuthManager:
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
log.debug("Assigning the only user the admin role")
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
log.debug("Assigning the first user the admin role")
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
log.debug("Running OAUTH Role management")
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
@@ -93,17 +101,24 @@ class OAuthManager:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
log.debug(f"Oauth Roles claim: {oauth_claim}")
log.debug(f"User roles from oauth: {oauth_roles}")
log.debug(f"Accepted user roles: {oauth_allowed_roles}")
log.debug(f"Accepted admin roles: {oauth_admin_roles}")
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
log.debug("Assigned user the user role")
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
log.debug("Assigned user the admin role")
role = "admin"
break
else:
@@ -117,16 +132,27 @@ class OAuthManager:
return role
def update_user_groups(self, user, user_data, default_permissions):
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
log.debug(f"Oauth Groups claim: {oauth_claim}")
log.debug(f"User oauth groups: {user_oauth_groups}")
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
log.debug(
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
)
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
)
user_ids = group_model.user_ids
user_ids = [i for i in user_ids if i != user.id]
@@ -152,6 +178,9 @@ class OAuthManager:
gm.name == group_model.name for gm in user_current_groups
):
# Add user to group
log.debug(
f"Adding user to group {group_model.name} as it was found in their oauth groups"
)
user_ids = group_model.user_ids
user_ids.append(user.id)
@@ -193,7 +222,7 @@ class OAuthManager:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token.get("userinfo")
if not user_data:
if not user_data or "email" not in user_data:
user_data: UserInfo = await client.userinfo(token=token)
if not user_data:
log.warning(f"OAuth callback failed, user data is missing: {token}")
@@ -261,15 +290,20 @@ class OAuthManager:
}
async with aiohttp.ClientSession() as session:
async with session.get(picture_url, **get_kwargs) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
if resp.ok:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(
picture_url
)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
else:
picture_url = "/user.png"
except Exception as e:
log.error(
f"Error downloading profile image '{picture_url}': {e}"

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from html import escape
from markdown import markdown
@@ -41,13 +42,13 @@ class PDFGenerator:
def _build_html_message(self, message: Dict[str, Any]) -> str:
"""Build HTML for a single message."""
role = message.get("role", "user")
content = message.get("content", "")
role = escape(message.get("role", "user"))
content = escape(message.get("content", ""))
timestamp = message.get("timestamp")
model = message.get("model") if role == "assistant" else ""
model = escape(message.get("model") if role == "assistant" else "")
date_str = self.format_timestamp(timestamp) if timestamp else ""
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
# extends pymdownx extension to convert markdown to html.
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
@@ -76,6 +77,7 @@ class PDFGenerator:
def _generate_html_body(self) -> str:
"""Generate the full HTML body for the PDF."""
escaped_title = escape(self.form_data.title)
return f"""
<html>
<head>
@@ -84,7 +86,7 @@ class PDFGenerator:
<body>
<div>
<div>
<h2>{self.form_data.title}</h2>
<h2>{escaped_title}</h2>
{self.messages_html}
</div>
</div>