mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into bug/user-signup/fix-oauth-username-claim-has-no-effect
This commit is contained in:
@@ -660,6 +660,7 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
|
||||
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
|
||||
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
|
||||
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
|
||||
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
|
||||
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
|
||||
|
||||
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
|
||||
@@ -1325,6 +1326,48 @@ Your task is to synthesize these responses into a single, high-quality response.
|
||||
Responses from models: {{responses}}"""
|
||||
|
||||
|
||||
####################################
|
||||
# Code Interpreter
|
||||
####################################
|
||||
|
||||
ENABLE_CODE_INTERPRETER = PersistentConfig(
|
||||
"ENABLE_CODE_INTERPRETER",
|
||||
"code_interpreter.enable",
|
||||
os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_ENGINE = PersistentConfig(
|
||||
"CODE_INTERPRETER_ENGINE",
|
||||
"code_interpreter.engine",
|
||||
os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_URL",
|
||||
"code_interpreter.jupyter.url",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH",
|
||||
"code_interpreter.jupyter.auth",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
|
||||
)
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
|
||||
"code_interpreter.jupyter.auth_token",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
|
||||
"code_interpreter.jupyter.auth_password",
|
||||
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_CODE_INTERPRETER_PROMPT = """
|
||||
#### Tools Available
|
||||
|
||||
@@ -1645,7 +1688,7 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
|
||||
# This ensures the highest level of safety and reliability of the information sources.
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
|
||||
"rag.rag.web.search.domain.filter_list",
|
||||
"rag.web.search.domain.filter_list",
|
||||
[
|
||||
# "wikipedia.com",
|
||||
# "wikimedia.org",
|
||||
@@ -2012,6 +2055,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
|
||||
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Add Deepgram configuration
|
||||
DEEPGRAM_API_KEY = PersistentConfig(
|
||||
"DEEPGRAM_API_KEY",
|
||||
"audio.stt.deepgram.api_key",
|
||||
os.getenv("DEEPGRAM_API_KEY", ""),
|
||||
)
|
||||
|
||||
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"AUDIO_STT_OPENAI_API_BASE_URL",
|
||||
|
||||
@@ -92,6 +92,7 @@ log_sources = [
|
||||
"RAG",
|
||||
"WEBHOOK",
|
||||
"SOCKET",
|
||||
"OAUTH",
|
||||
]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
||||
@@ -97,6 +97,13 @@ from open_webui.config import (
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
OPENAI_API_CONFIGS,
|
||||
# Code Interpreter
|
||||
ENABLE_CODE_INTERPRETER,
|
||||
CODE_INTERPRETER_ENGINE,
|
||||
CODE_INTERPRETER_JUPYTER_URL,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
# Image
|
||||
AUTOMATIC1111_API_AUTH,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
@@ -130,6 +137,7 @@ from open_webui.config import (
|
||||
AUDIO_TTS_AZURE_SPEECH_REGION,
|
||||
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
WHISPER_MODEL,
|
||||
DEEPGRAM_API_KEY,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
# Retrieval
|
||||
@@ -569,6 +577,23 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
########################################
|
||||
#
|
||||
# CODE INTERPRETER
|
||||
#
|
||||
########################################
|
||||
|
||||
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
|
||||
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
|
||||
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
|
||||
########################################
|
||||
#
|
||||
@@ -611,6 +636,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
||||
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
||||
|
||||
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
||||
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
|
||||
|
||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||
@@ -753,6 +779,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
|
||||
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
|
||||
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
|
||||
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
|
||||
|
||||
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
|
||||
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
|
||||
|
||||
@@ -1013,12 +1040,14 @@ async def get_app_config(request: Request):
|
||||
{
|
||||
"enable_channels": app.state.config.ENABLE_CHANNELS,
|
||||
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
|
||||
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
}
|
||||
if user is not None
|
||||
else {}
|
||||
|
||||
@@ -470,7 +470,7 @@ class ChatTable:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# it is possible that the shared link was deleted. hence,
|
||||
# we check if the chat is still shared by checkng if a chat with the share_id exists
|
||||
# we check if the chat is still shared by checking if a chat with the share_id exists
|
||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||
|
||||
if chat:
|
||||
|
||||
@@ -113,6 +113,34 @@ class OpenSearchClient:
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{collection_name}",
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def get_or_create_index(self, index_name: str, dimension: int):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension)
|
||||
|
||||
@@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
list[SearchResult]: A list of search results
|
||||
"""
|
||||
jina_search_endpoint = "https://s.jina.ai/"
|
||||
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
url = str(URL(jina_search_endpoint + query))
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": api_key,
|
||||
"X-Retain-Images": "none",
|
||||
}
|
||||
|
||||
payload = {"q": query, "count": count if count <= 10 else 10}
|
||||
|
||||
url = str(URL(jina_search_endpoint))
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["data"][:count]:
|
||||
for result in data["data"]:
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import requests
|
||||
import mimetypes
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
@@ -138,6 +139,7 @@ class STTConfigForm(BaseModel):
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
DEEPGRAM_API_KEY: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
@@ -165,6 +167,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -190,6 +193,7 @@ async def update_audio_config(
|
||||
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
request.app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -214,6 +218,7 @@ async def update_audio_config(
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -521,6 +526,69 @@ def transcribe(request: Request, file_path):
|
||||
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
elif request.app.state.config.STT_ENGINE == "deepgram":
|
||||
try:
|
||||
# Determine the MIME type of the file
|
||||
mime, _ = mimetypes.guess_type(file_path)
|
||||
if not mime:
|
||||
mime = "audio/wav" # fallback to wav if undetectable
|
||||
|
||||
# Read the audio file
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Build headers and parameters
|
||||
headers = {
|
||||
"Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
|
||||
"Content-Type": mime,
|
||||
}
|
||||
|
||||
# Add model if specified
|
||||
params = {}
|
||||
if request.app.state.config.STT_MODEL:
|
||||
params["model"] = request.app.state.config.STT_MODEL
|
||||
|
||||
# Make request to Deepgram API
|
||||
r = requests.post(
|
||||
"https://api.deepgram.com/v1/listen",
|
||||
headers=headers,
|
||||
params=params,
|
||||
data=file_data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
response_data = r.json()
|
||||
|
||||
# Extract transcript from Deepgram response
|
||||
try:
|
||||
transcript = response_data["results"]["channels"][0]["alternatives"][
|
||||
0
|
||||
].get("transcript", "")
|
||||
except (KeyError, IndexError) as e:
|
||||
log.error(f"Malformed response from Deepgram: {str(e)}")
|
||||
raise Exception(
|
||||
"Failed to parse Deepgram response - unexpected response format"
|
||||
)
|
||||
data = {"text": transcript.strip()}
|
||||
|
||||
# Save transcript
|
||||
transcript_file = f"{file_dir}/{id}.json"
|
||||
with open(transcript_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
|
||||
|
||||
|
||||
def compress_audio(file_path):
|
||||
if os.path.getsize(file_path) > MAX_FILE_SIZE:
|
||||
|
||||
@@ -36,6 +36,61 @@ async def export_config(user=Depends(get_admin_user)):
|
||||
return get_config()
|
||||
|
||||
|
||||
############################
|
||||
# CodeInterpreterConfig
|
||||
############################
|
||||
class CodeInterpreterConfigForm(BaseModel):
|
||||
ENABLE_CODE_INTERPRETER: bool
|
||||
CODE_INTERPRETER_ENGINE: str
|
||||
CODE_INTERPRETER_JUPYTER_URL: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
|
||||
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
|
||||
|
||||
|
||||
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
|
||||
async def set_code_interpreter_config(
|
||||
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_URL
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
)
|
||||
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
)
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
|
||||
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
)
|
||||
|
||||
return {
|
||||
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
|
||||
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
|
||||
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
|
||||
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
|
||||
}
|
||||
|
||||
|
||||
############################
|
||||
# SetDefaultModels
|
||||
############################
|
||||
|
||||
@@ -11,10 +11,8 @@ import re
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
|
||||
import requests
|
||||
|
||||
from fastapi import (
|
||||
@@ -990,6 +988,8 @@ async def generate_chat_completion(
|
||||
)
|
||||
|
||||
payload = {**form_data.model_dump(exclude_none=True)}
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
|
||||
model_id = payload["model"]
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
@@ -1408,9 +1408,10 @@ async def download_model(
|
||||
return None
|
||||
|
||||
|
||||
# TODO: Progress bar does not reflect size & duration of upload.
|
||||
@router.post("/models/upload")
|
||||
@router.post("/models/upload/{url_idx}")
|
||||
def upload_model(
|
||||
async def upload_model(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
url_idx: Optional[int] = None,
|
||||
@@ -1419,59 +1420,85 @@ def upload_model(
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
# --- P1: save file locally ---
|
||||
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
|
||||
with open(file_path, "wb") as out_f:
|
||||
while True:
|
||||
chunk = file.file.read(chunk_size)
|
||||
# log.info(f"Chunk: {str(chunk)}") # DEBUG
|
||||
if not chunk:
|
||||
break
|
||||
out_f.write(chunk)
|
||||
|
||||
# Save file in chunks
|
||||
with open(file_path, "wb+") as f:
|
||||
for chunk in file.file:
|
||||
f.write(chunk)
|
||||
|
||||
def file_process_stream():
|
||||
async def file_process_stream():
|
||||
nonlocal ollama_url
|
||||
total_size = os.path.getsize(file_path)
|
||||
chunk_size = 1024 * 1024
|
||||
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
|
||||
|
||||
# --- P2: SSE progress + calculate sha256 hash ---
|
||||
file_hash = calculate_sha256(file_path, chunk_size)
|
||||
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
total = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
done = True
|
||||
continue
|
||||
|
||||
total += len(chunk)
|
||||
progress = round((total / total_size) * 100, 2)
|
||||
|
||||
res = {
|
||||
bytes_read = 0
|
||||
while chunk := f.read(chunk_size):
|
||||
bytes_read += len(chunk)
|
||||
progress = round(bytes_read / total_size * 100, 2)
|
||||
data_msg = {
|
||||
"progress": progress,
|
||||
"total": total_size,
|
||||
"completed": total,
|
||||
"completed": bytes_read,
|
||||
}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
yield f"data: {json.dumps(data_msg)}\n\n"
|
||||
|
||||
if done:
|
||||
f.seek(0)
|
||||
hashed = calculate_sha256(f)
|
||||
f.seek(0)
|
||||
# --- P3: Upload to ollama /api/blobs ---
|
||||
with open(file_path, "rb") as f:
|
||||
url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
|
||||
response = requests.post(url, data=f)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=f)
|
||||
if response.ok:
|
||||
log.info(f"Uploaded to /api/blobs") # DEBUG
|
||||
# Remove local file
|
||||
os.remove(file_path)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file.filename,
|
||||
}
|
||||
os.remove(file_path)
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
"Ollama: Could not create blob, Please try again."
|
||||
)
|
||||
# Create model in ollama
|
||||
model_name, ext = os.path.splitext(file.filename)
|
||||
log.info(f"Created Model: {model_name}") # DEBUG
|
||||
|
||||
create_payload = {
|
||||
"model": model_name,
|
||||
# Reference the file by its original name => the uploaded blob's digest
|
||||
"files": {file.filename: f"sha256:{file_hash}"},
|
||||
}
|
||||
log.info(f"Model Payload: {create_payload}") # DEBUG
|
||||
|
||||
# Call ollama /api/create
|
||||
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
|
||||
create_resp = requests.post(
|
||||
url=f"{ollama_url}/api/create",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(create_payload),
|
||||
)
|
||||
|
||||
if create_resp.ok:
|
||||
log.info(f"API SUCCESS!") # DEBUG
|
||||
done_msg = {
|
||||
"done": True,
|
||||
"blob": f"sha256:{file_hash}",
|
||||
"name": file.filename,
|
||||
"model_created": model_name,
|
||||
}
|
||||
yield f"data: {json.dumps(done_msg)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to create model in Ollama. {create_resp.text}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Ollama: Could not create blob, Please try again.")
|
||||
|
||||
except Exception as e:
|
||||
res = {"error": str(e)}
|
||||
|
||||
@@ -75,9 +75,9 @@ async def cleanup_response(
|
||||
await session.close()
|
||||
|
||||
|
||||
def openai_o1_handler(payload):
|
||||
def openai_o1_o3_handler(payload):
|
||||
"""
|
||||
Handle O1 specific parameters
|
||||
Handle o1, o3 specific parameters
|
||||
"""
|
||||
if "max_tokens" in payload:
|
||||
# Remove "max_tokens" from the payload
|
||||
@@ -621,10 +621,10 @@ async def generate_chat_completion(
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1 = payload["model"].lower().startswith("o1-")
|
||||
if is_o1:
|
||||
payload = openai_o1_handler(payload)
|
||||
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
|
||||
if is_o1_o3:
|
||||
payload = openai_o1_o3_handler(payload)
|
||||
elif "api.openai.com" not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
if "max_completion_tokens" in payload:
|
||||
|
||||
@@ -392,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -441,6 +442,7 @@ class WebSearchConfig(BaseModel):
|
||||
exa_api_key: Optional[str] = None
|
||||
result_count: Optional[int] = None
|
||||
concurrent_requests: Optional[int] = None
|
||||
domain_filter_list: Optional[List[str]] = []
|
||||
|
||||
|
||||
class WebConfig(BaseModel):
|
||||
@@ -553,6 +555,9 @@ async def update_rag_config(
|
||||
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||
form_data.web.search.concurrent_requests
|
||||
)
|
||||
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
|
||||
form_data.web.search.domain_filter_list
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
@@ -599,6 +604,7 @@ async def update_rag_config(
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ from open_webui.config import (
|
||||
S3_ACCESS_KEY_ID,
|
||||
S3_BUCKET_NAME,
|
||||
S3_ENDPOINT_URL,
|
||||
S3_KEY_PREFIX,
|
||||
S3_REGION_NAME,
|
||||
S3_SECRET_ACCESS_KEY,
|
||||
GCS_BUCKET_NAME,
|
||||
@@ -93,15 +94,17 @@ class S3StorageProvider(StorageProvider):
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
)
|
||||
self.bucket_name = S3_BUCKET_NAME
|
||||
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
|
||||
|
||||
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
|
||||
"""Handles uploading of the file to S3 storage."""
|
||||
_, file_path = LocalStorageProvider.upload_file(file, filename)
|
||||
try:
|
||||
self.s3_client.upload_file(file_path, self.bucket_name, filename)
|
||||
s3_key = os.path.join(self.key_prefix, filename)
|
||||
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
|
||||
return (
|
||||
open(file_path, "rb").read(),
|
||||
"s3://" + self.bucket_name + "/" + filename,
|
||||
"s3://" + self.bucket_name + "/" + s3_key,
|
||||
)
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error uploading file to S3: {e}")
|
||||
@@ -109,18 +112,18 @@ class S3StorageProvider(StorageProvider):
|
||||
def get_file(self, file_path: str) -> str:
|
||||
"""Handles downloading of the file from S3 storage."""
|
||||
try:
|
||||
bucket_name, key = file_path.split("//")[1].split("/")
|
||||
local_file_path = f"{UPLOAD_DIR}/{key}"
|
||||
self.s3_client.download_file(bucket_name, key, local_file_path)
|
||||
s3_key = self._extract_s3_key(file_path)
|
||||
local_file_path = self._get_local_file_path(s3_key)
|
||||
self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
|
||||
return local_file_path
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error downloading file from S3: {e}")
|
||||
|
||||
def delete_file(self, file_path: str) -> None:
|
||||
"""Handles deletion of the file from S3 storage."""
|
||||
filename = file_path.split("/")[-1]
|
||||
try:
|
||||
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
s3_key = self._extract_s3_key(file_path)
|
||||
self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Error deleting file from S3: {e}")
|
||||
|
||||
@@ -133,6 +136,10 @@ class S3StorageProvider(StorageProvider):
|
||||
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
|
||||
if "Contents" in response:
|
||||
for content in response["Contents"]:
|
||||
# Skip objects that were not uploaded from open-webui in the first place
|
||||
if not content["Key"].startswith(self.key_prefix):
|
||||
continue
|
||||
|
||||
self.s3_client.delete_object(
|
||||
Bucket=self.bucket_name, Key=content["Key"]
|
||||
)
|
||||
@@ -142,6 +149,13 @@ class S3StorageProvider(StorageProvider):
|
||||
# Always delete from local storage
|
||||
LocalStorageProvider.delete_all_files()
|
||||
|
||||
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
|
||||
def _extract_s3_key(self, full_file_path: str) -> str:
|
||||
return "/".join(full_file_path.split("//")[1].split("/")[1:])
|
||||
|
||||
def _get_local_file_path(self, s3_key: str) -> str:
|
||||
return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
|
||||
|
||||
|
||||
class GCSStorageProvider(StorageProvider):
|
||||
def __init__(self):
|
||||
|
||||
@@ -44,6 +44,10 @@ from open_webui.utils.response import (
|
||||
convert_response_ollama_to_openai,
|
||||
convert_streaming_response_ollama_to_openai,
|
||||
)
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
|
||||
@@ -177,116 +181,38 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
|
||||
__event_call__ = get_event_call(
|
||||
{
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
)
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if not hasattr(function_module, "outlet"):
|
||||
continue
|
||||
try:
|
||||
outlet = function_module.outlet
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(outlet)
|
||||
params = {"body": data}
|
||||
|
||||
# Extra parameters to be passed to the function
|
||||
extra_params = {
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__request__": request,
|
||||
}
|
||||
|
||||
# Add extra params in contained in function signature
|
||||
for key, value in extra_params.items():
|
||||
if key in sig.parameters:
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, user.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
params = {**params, "__user__": __user__}
|
||||
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(**params)
|
||||
else:
|
||||
data = outlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
return data
|
||||
try:
|
||||
result, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="outlet",
|
||||
form_data=data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
|
||||
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
|
||||
|
||||
153
backend/open_webui/utils/code_interpreter.py
Normal file
153
backend/open_webui/utils/code_interpreter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import websockets
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
async def execute_code_jupyter(
|
||||
jupyter_url, code, token=None, password=None, timeout=10
|
||||
):
|
||||
"""
|
||||
Executes Python code in a Jupyter kernel.
|
||||
Supports authentication with a token or password.
|
||||
:param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
|
||||
:param code: Code to execute
|
||||
:param token: Jupyter authentication token (optional)
|
||||
:param password: Jupyter password (optional)
|
||||
:param timeout: WebSocket timeout in seconds (default: 10s)
|
||||
:return: Dictionary with stdout, stderr, and result
|
||||
"""
|
||||
session = requests.Session() # Maintain cookies
|
||||
headers = {} # Headers for requests
|
||||
|
||||
# Authenticate using password
|
||||
if password and not token:
|
||||
try:
|
||||
login_url = urljoin(jupyter_url, "/login")
|
||||
response = session.get(login_url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Retrieve `_xsrf` token
|
||||
xsrf_token = session.cookies.get("_xsrf")
|
||||
if not xsrf_token:
|
||||
raise ValueError("Failed to fetch _xsrf token")
|
||||
|
||||
# Send login request
|
||||
login_data = {"_xsrf": xsrf_token, "password": password}
|
||||
login_response = session.post(
|
||||
login_url, data=login_data, cookies=session.cookies
|
||||
)
|
||||
login_response.raise_for_status()
|
||||
|
||||
# Update headers with `_xsrf`
|
||||
headers["X-XSRFToken"] = xsrf_token
|
||||
except Exception as e:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Authentication Error: {str(e)}",
|
||||
"result": "",
|
||||
}
|
||||
|
||||
# Construct API URLs with authentication token if provided
|
||||
params = f"?token={token}" if token else ""
|
||||
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
|
||||
|
||||
try:
|
||||
# Include cookies if authenticating with password
|
||||
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
|
||||
response.raise_for_status()
|
||||
kernel_id = response.json()["id"]
|
||||
|
||||
# Construct WebSocket URL
|
||||
websocket_url = urljoin(
|
||||
jupyter_url.replace("http", "ws"),
|
||||
f"/api/kernels/{kernel_id}/channels{params}",
|
||||
)
|
||||
|
||||
# **IMPORTANT:** Include authentication cookies for WebSockets
|
||||
ws_headers = {}
|
||||
if password and not token:
|
||||
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
|
||||
cookies = {name: value for name, value in session.cookies.items()}
|
||||
ws_headers["Cookie"] = "; ".join(
|
||||
[f"{name}={value}" for name, value in cookies.items()]
|
||||
)
|
||||
|
||||
# Connect to the WebSocket
|
||||
async with websockets.connect(
|
||||
websocket_url, additional_headers=ws_headers
|
||||
) as ws:
|
||||
msg_id = str(uuid.uuid4())
|
||||
|
||||
# Send execution request
|
||||
execute_request = {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "user",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": "",
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"channel": "shell",
|
||||
}
|
||||
await ws.send(json.dumps(execute_request))
|
||||
|
||||
# Collect execution results
|
||||
stdout, stderr, result = "", "", None
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(ws.recv(), timeout)
|
||||
message_data = json.loads(message)
|
||||
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
|
||||
msg_type = message_data.get("msg_type")
|
||||
if msg_type == "stream":
|
||||
if message_data["content"]["name"] == "stdout":
|
||||
stdout += message_data["content"]["text"]
|
||||
elif message_data["content"]["name"] == "stderr":
|
||||
stderr += message_data["content"]["text"]
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
result = message_data["content"]["data"].get(
|
||||
"text/plain", ""
|
||||
)
|
||||
elif msg_type == "error":
|
||||
stderr += "\n".join(message_data["content"]["traceback"])
|
||||
elif (
|
||||
msg_type == "status"
|
||||
and message_data["content"]["execution_state"] == "idle"
|
||||
):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
stderr += "\nExecution timed out."
|
||||
break
|
||||
except Exception as e:
|
||||
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
|
||||
finally:
|
||||
# Shutdown the kernel
|
||||
if kernel_id:
|
||||
requests.delete(
|
||||
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
|
||||
)
|
||||
|
||||
return {
|
||||
"stdout": stdout.strip(),
|
||||
"stderr": stderr.strip(),
|
||||
"result": result.strip() if result else "",
|
||||
}
|
||||
|
||||
|
||||
# Example Usage
|
||||
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
|
||||
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))
|
||||
99
backend/open_webui/utils/filter.py
Normal file
99
backend/open_webui/utils/filter.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import inspect
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.models.functions import Functions
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel to include vavles
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
|
||||
async def process_filter_functions(
|
||||
request, filter_ids, filter_type, form_data, extra_params
|
||||
):
|
||||
skip_files = None
|
||||
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Prepare parameters
|
||||
sig = inspect.signature(handler)
|
||||
params = {"body": form_data} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in sig.parameters
|
||||
}
|
||||
|
||||
# Handle user parameters
|
||||
if "__user__" in sig.parameters:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# Execute handler
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
form_data = await handler(**params)
|
||||
else:
|
||||
form_data = handler(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in {filter_type} handler {filter_id}: {e}")
|
||||
raise e
|
||||
|
||||
# Handle file cleanup for inlet
|
||||
if skip_files and "files" in form_data.get("metadata", {}):
|
||||
del form_data["metadata"]["files"]
|
||||
|
||||
return form_data, {}
|
||||
@@ -161,7 +161,7 @@ async def comfyui_generate_image(
|
||||
seed = (
|
||||
payload.seed
|
||||
if payload.seed
|
||||
else random.randint(0, 18446744073709551614)
|
||||
else random.randint(0, 1125899906842624)
|
||||
)
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"][node.key] = seed
|
||||
|
||||
@@ -68,7 +68,11 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
||||
skip_files = None
|
||||
|
||||
def get_filter_function_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
filter_ids = get_filter_function_ids(model)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if hasattr(function_module, "inlet"):
|
||||
try:
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Create a dictionary of parameters to be passed to the function
|
||||
params = {"body": body} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in inspect.signature(inlet).parameters
|
||||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
body = await inlet(**params)
|
||||
else:
|
||||
body = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise e
|
||||
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
@@ -572,13 +483,13 @@ async def chat_image_generation_handler(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"description": f"An error occured while generating an image",
|
||||
"description": f"An error occurred while generating an image",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
|
||||
|
||||
if system_message_content:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
@@ -706,6 +617,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
@@ -782,8 +694,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
form_data, flags = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="inlet",
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
@@ -1122,6 +1038,20 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
def split_content_and_whitespace(content):
|
||||
content_stripped = content.rstrip()
|
||||
original_whitespace = (
|
||||
content[len(content_stripped) :]
|
||||
if len(content) > len(content_stripped)
|
||||
else ""
|
||||
)
|
||||
return content_stripped, original_whitespace
|
||||
|
||||
def is_opening_code_block(content):
|
||||
backtick_segments = content.split("```")
|
||||
# Even number of segments means the last backticks are opening a new block
|
||||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||||
|
||||
# Handle as a background task
|
||||
async def post_response_handler(response, events):
|
||||
def serialize_content_blocks(content_blocks, raw=False):
|
||||
@@ -1188,6 +1118,19 @@ async def process_chat_response(
|
||||
output = block.get("output", None)
|
||||
lang = attributes.get("lang", "")
|
||||
|
||||
content_stripped, original_whitespace = (
|
||||
split_content_and_whitespace(content)
|
||||
)
|
||||
if is_opening_code_block(content_stripped):
|
||||
# Remove trailing backticks that would open a new block
|
||||
content = (
|
||||
content_stripped.rstrip("`").rstrip()
|
||||
+ original_whitespace
|
||||
)
|
||||
else:
|
||||
# Keep content as is - either closing backticks or no backticks
|
||||
content = content_stripped + original_whitespace
|
||||
|
||||
if output:
|
||||
output = html.escape(json.dumps(output))
|
||||
|
||||
@@ -1242,10 +1185,10 @@ async def process_chat_response(
|
||||
match.end() :
|
||||
] # Content after opening tag
|
||||
|
||||
# Remove the start tag from the currently handling text block
|
||||
# Remove the start tag and after from the currently handling text block
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].replace(match.group(0), "")
|
||||
].replace(match.group(0) + after_tag, "")
|
||||
|
||||
if before_tag:
|
||||
content_blocks[-1]["content"] = before_tag
|
||||
@@ -1708,15 +1651,45 @@ async def process_chat_response(
|
||||
output = ""
|
||||
try:
|
||||
if content_blocks[-1]["attributes"].get("type") == "code":
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": content_blocks[-1]["content"],
|
||||
},
|
||||
code = content_blocks[-1]["content"]
|
||||
|
||||
if (
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE
|
||||
== "pyodide"
|
||||
):
|
||||
output = await event_caller(
|
||||
{
|
||||
"type": "execute:python",
|
||||
"data": {
|
||||
"id": str(uuid4()),
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif (
|
||||
request.app.state.config.CODE_INTERPRETER_ENGINE
|
||||
== "jupyter"
|
||||
):
|
||||
output = await execute_code_jupyter(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
|
||||
code,
|
||||
(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
|
||||
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
== "token"
|
||||
else None
|
||||
),
|
||||
(
|
||||
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
|
||||
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
|
||||
== "password"
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
output = {
|
||||
"stdout": "Code interpreter engine not configured."
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(output, dict):
|
||||
stdout = output.get("stdout", "")
|
||||
|
||||
@@ -244,11 +244,12 @@ def get_gravatar_url(email):
|
||||
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
|
||||
|
||||
|
||||
def calculate_sha256(file):
|
||||
def calculate_sha256(file_path, chunk_size):
|
||||
# Compute SHA-256 hash of a file efficiently in chunks
|
||||
sha256 = hashlib.sha256()
|
||||
# Read the file in chunks to efficiently handle large files
|
||||
for chunk in iter(lambda: file.read(8192), b""):
|
||||
sha256.update(chunk)
|
||||
with open(file_path, "rb") as f:
|
||||
while chunk := f.read(chunk_size):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
@@ -40,7 +41,11 @@ from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.auth import get_password_hash, create_token
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
|
||||
|
||||
auth_manager_config = AppConfig()
|
||||
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
@@ -72,12 +77,15 @@ class OAuthManager:
|
||||
def get_user_role(self, user, user_data):
|
||||
if user and Users.get_num_users() == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
log.debug("Assigning the only user the admin role")
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
log.debug("Assigning the first user the admin role")
|
||||
return "admin"
|
||||
|
||||
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||
log.debug("Running OAUTH Role management")
|
||||
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
|
||||
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
|
||||
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
|
||||
@@ -93,17 +101,24 @@ class OAuthManager:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||
|
||||
log.debug(f"Oauth Roles claim: {oauth_claim}")
|
||||
log.debug(f"User roles from oauth: {oauth_roles}")
|
||||
log.debug(f"Accepted user roles: {oauth_allowed_roles}")
|
||||
log.debug(f"Accepted admin roles: {oauth_admin_roles}")
|
||||
|
||||
# If any roles are found, check if they match the allowed or admin roles
|
||||
if oauth_roles:
|
||||
# If role management is enabled, and matching roles are provided, use the roles
|
||||
for allowed_role in oauth_allowed_roles:
|
||||
# If the user has any of the allowed roles, assign the role "user"
|
||||
if allowed_role in oauth_roles:
|
||||
log.debug("Assigned user the user role")
|
||||
role = "user"
|
||||
break
|
||||
for admin_role in oauth_admin_roles:
|
||||
# If the user has any of the admin roles, assign the role "admin"
|
||||
if admin_role in oauth_roles:
|
||||
log.debug("Assigned user the admin role")
|
||||
role = "admin"
|
||||
break
|
||||
else:
|
||||
@@ -117,16 +132,27 @@ class OAuthManager:
|
||||
return role
|
||||
|
||||
def update_user_groups(self, user, user_data, default_permissions):
|
||||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||
log.debug(f"User oauth groups: {user_oauth_groups}")
|
||||
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
|
||||
log.debug(
|
||||
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
|
||||
)
|
||||
|
||||
# Remove groups that user is no longer a part of
|
||||
for group_model in user_current_groups:
|
||||
if group_model.name not in user_oauth_groups:
|
||||
# Remove group from user
|
||||
log.debug(
|
||||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
@@ -152,6 +178,9 @@ class OAuthManager:
|
||||
gm.name == group_model.name for gm in user_current_groups
|
||||
):
|
||||
# Add user to group
|
||||
log.debug(
|
||||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
@@ -193,7 +222,7 @@ class OAuthManager:
|
||||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token.get("userinfo")
|
||||
if not user_data:
|
||||
if not user_data or "email" not in user_data:
|
||||
user_data: UserInfo = await client.userinfo(token=token)
|
||||
if not user_data:
|
||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||
@@ -261,15 +290,20 @@ class OAuthManager:
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
if resp.ok:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(
|
||||
picture_url
|
||||
)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
else:
|
||||
picture_url = "/user.png"
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error downloading profile image '{picture_url}': {e}"
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
from html import escape
|
||||
|
||||
from markdown import markdown
|
||||
|
||||
@@ -41,13 +42,13 @@ class PDFGenerator:
|
||||
|
||||
def _build_html_message(self, message: Dict[str, Any]) -> str:
|
||||
"""Build HTML for a single message."""
|
||||
role = message.get("role", "user")
|
||||
content = message.get("content", "")
|
||||
role = escape(message.get("role", "user"))
|
||||
content = escape(message.get("content", ""))
|
||||
timestamp = message.get("timestamp")
|
||||
|
||||
model = message.get("model") if role == "assistant" else ""
|
||||
model = escape(message.get("model") if role == "assistant" else "")
|
||||
|
||||
date_str = self.format_timestamp(timestamp) if timestamp else ""
|
||||
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
|
||||
|
||||
# extends pymdownx extension to convert markdown to html.
|
||||
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
|
||||
@@ -76,6 +77,7 @@ class PDFGenerator:
|
||||
|
||||
def _generate_html_body(self) -> str:
|
||||
"""Generate the full HTML body for the PDF."""
|
||||
escaped_title = escape(self.form_data.title)
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
@@ -84,7 +86,7 @@ class PDFGenerator:
|
||||
<body>
|
||||
<div>
|
||||
<div>
|
||||
<h2>{self.form_data.title}</h2>
|
||||
<h2>{escaped_title}</h2>
|
||||
{self.messages_html}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user