Merge remote-tracking branch 'upstream/dev' into playwright

This commit is contained in:
Rory
2025-02-12 22:32:44 -06:00
120 changed files with 3485 additions and 1184 deletions

View File

@@ -683,6 +683,17 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# DIRECT CONNECTIONS
####################################
ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
"ENABLE_DIRECT_CONNECTIONS",
"direct.enable",
os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
)
####################################
# OLLAMA_BASE_URL
####################################
@@ -1326,6 +1337,54 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""
####################################
# Code Interpreter
####################################
ENABLE_CODE_INTERPRETER = PersistentConfig(
"ENABLE_CODE_INTERPRETER",
"code_interpreter.enable",
os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
)
CODE_INTERPRETER_ENGINE = PersistentConfig(
"CODE_INTERPRETER_ENGINE",
"code_interpreter.engine",
os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
)
CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
"CODE_INTERPRETER_PROMPT_TEMPLATE",
"code_interpreter.prompt_template",
os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
)
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_URL",
"code_interpreter.jupyter.url",
os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH",
"code_interpreter.jupyter.auth",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
"code_interpreter.jupyter.auth_token",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
)
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
"code_interpreter.jupyter.auth_password",
os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
)
DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
@@ -1336,9 +1395,8 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
- **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
@@ -1691,6 +1749,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
BOCHA_SEARCH_API_KEY = PersistentConfig(
"BOCHA_SEARCH_API_KEY",
"rag.web.search.bocha_search_api_key",
os.getenv("BOCHA_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",

View File

@@ -97,6 +97,16 @@ from open_webui.config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
# Code Interpreter
ENABLE_CODE_INTERPRETER,
CODE_INTERPRETER_ENGINE,
CODE_INTERPRETER_PROMPT_TEMPLATE,
CODE_INTERPRETER_JUPYTER_URL,
CODE_INTERPRETER_JUPYTER_AUTH,
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@@ -183,6 +193,7 @@ from open_webui.config import (
EXA_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
BOCHA_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
@@ -325,7 +336,11 @@ class SPAStaticFiles(StaticFiles):
return await super().get_response(path, scope)
except (HTTPException, StarletteHTTPException) as ex:
if ex.status_code == 404:
return await super().get_response("index.html", scope)
if path.endswith(".js"):
# Return 404 for javascript files
raise ex
else:
return await super().get_response("index.html", scope)
else:
raise ex
@@ -392,6 +407,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
app.state.OPENAI_MODELS = {}
########################################
#
# DIRECT CONNECTIONS
#
########################################
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
########################################
#
# WEBUI
@@ -517,6 +540,7 @@ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -574,6 +598,24 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
########################################
#
# CODE INTERPRETER
#
########################################
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
########################################
#
@@ -759,6 +801,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -1017,15 +1060,17 @@ async def get_app_config(request: Request):
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
**(
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
}
if user is not None
else {}

View File

@@ -271,6 +271,24 @@ class UsersTable:
except Exception:
return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
with get_db() as db:
user_settings = db.query(User).filter_by(id=id).first().settings
if user_settings is None:
user_settings = {}
user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def delete_user_by_id(self, id: str) -> bool:
try:
# Remove User from Groups

View File

@@ -120,18 +120,12 @@ class OpenSearchClient:
return None
query_body = {
"query": {
"bool": {
"filter": []
}
},
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({
"term": {field: value}
})
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
size = limit if limit else 10
@@ -139,7 +133,7 @@ class OpenSearchClient:
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}",
body=query_body,
size=size
size=size,
)
return self._result_to_get_result(result)

View File

@@ -0,0 +1,72 @@
import logging
from typing import Optional
import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _parse_response(response):
result = {}
if "data" in response:
data = response["data"]
if "webPages" in data:
webPages = data["webPages"]
if "value" in webPages:
result["webpage"] = [
{
"id": item.get("id", ""),
"name": item.get("name", ""),
"url": item.get("url", ""),
"snippet": item.get("snippet", ""),
"summary": item.get("summary", ""),
"siteName": item.get("siteName", ""),
"siteIcon": item.get("siteIcon", ""),
"datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
}
for item in webPages["value"]
]
return result
def search_bocha(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Bocha Search API key
query (str): The query to search for
"""
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = json.dumps({
"query": query,
"summary": True,
"freshness": "noLimit",
"count": count
})
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
results = _parse_response(response.json())
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("name"),
snippet=result.get("summary")
)
for result in results.get("webpage", [])[:count]
]

View File

@@ -8,7 +8,6 @@ from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str,
search_engine_id: str,
@@ -17,34 +16,51 @@ def search_google_pse(
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
Handles pagination for counts greater than 10.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
Returns:
list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": count,
}
all_results = []
start_index = 1 # Google PSE start parameter is 1-based
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
while count > 0:
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": num_results_this_page,
"start": start_index,
}
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
if results: # check if results are returned. If not, no more pages to fetch.
all_results.extend(results)
count -= len(results) # Decrement count by the number of results fetched in this page.
start_index += 10 # Increment start index for the next page
else:
break # No more results from Google PSE, break the loop
json_response = response.json()
results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
all_results = get_filtered_results(all_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results
for result in all_results
]

View File

@@ -25,13 +25,10 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": api_key,
"X-Retain-Images": "none"
"X-Retain-Images": "none",
}
payload = {
"q": query,
"count": count if count <= 10 else 10
}
payload = {"q": query, "count": count if count <= 10 else 10}
url = str(URL(jina_search_endpoint))
response = requests.post(url, headers=headers, json=payload)

View File

@@ -560,10 +560,14 @@ def transcribe(request: Request, file_path):
# Extract transcript from Deepgram response
try:
transcript = response_data["results"]["channels"][0]["alternatives"][0].get("transcript", "")
transcript = response_data["results"]["channels"][0]["alternatives"][
0
].get("transcript", "")
except (KeyError, IndexError) as e:
log.error(f"Malformed response from Deepgram: {str(e)}")
raise Exception("Failed to parse Deepgram response - unexpected response format")
raise Exception(
"Failed to parse Deepgram response - unexpected response format"
)
data = {"text": transcript.strip()}
# Save transcript

View File

@@ -36,6 +36,98 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
############################
# Direct Connections Config
############################
class DirectConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
}
@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
async def set_direct_connections_config(
request: Request,
form_data: DirectConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
}
############################
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
CODE_INTERPRETER_JUPYTER_URL: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
}
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
form_data.CODE_INTERPRETER_JUPYTER_URL
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
return {
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
}
############################
# SetDefaultModels
############################

View File

@@ -3,30 +3,22 @@ import os
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote
from open_webui.storage.provider import Storage
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +33,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
try:
@@ -65,6 +60,7 @@ def upload_file(
"name": name,
"content_type": file.content_type,
"size": len(contents),
"data": file_metadata,
},
}
),
@@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files()
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

View File

@@ -1,32 +1,26 @@
import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
@@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
def load_b64_image_data(b64_str):
try:
image_id = str(uuid.uuid4())
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
else:
image_filename = f"{image_id}.png"
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
mime_type = "image/png"
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
return img_data, mime_type
except Exception as e:
log.exception(f"Error saving image: {e}")
log.exception(f"Error loading image data: {e}")
return None
def save_url_image(url, headers=None):
image_id = str(uuid.uuid4())
def load_url_image_data(url, headers=None):
try:
if headers:
r = requests.get(url, headers=headers)
@@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_filename
return r.content, mime_type
else:
log.error("Url does not point to an image.")
return None
@@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
headers={
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@router.post("/generations")
async def image_generations(
request: Request,
@@ -500,13 +478,9 @@ async def image_generations(
images = []
for image in res["data"]:
image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(data, f)
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -552,14 +526,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
image_filename = save_url_image(image["url"], headers)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(form_data.model_dump(exclude_none=True), f)
log.debug(f"images: {images}")
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
user,
)
images.append({"url": url})
return images
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@@ -604,13 +579,15 @@ async def image_generations(
images = []
for image in res["images"]:
image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
user,
)
images.append({"url": url})
return images
except Exception as e:
error = e

View File

@@ -1424,11 +1424,11 @@ async def upload_model(
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- P1: save file locally ---
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
with open(file_path, "wb") as out_f:
while True:
chunk = file.file.read(chunk_size)
#log.info(f"Chunk: {str(chunk)}") # DEBUG
# log.info(f"Chunk: {str(chunk)}") # DEBUG
if not chunk:
break
out_f.write(chunk)
@@ -1436,15 +1436,15 @@ async def upload_model(
async def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
# --- P2: SSE progress + calculate sha256 hash ---
file_hash = calculate_sha256(file_path, chunk_size)
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
try:
with open(file_path, "rb") as f:
bytes_read = 0
while chunk := f.read(chunk_size):
while chunk := f.read(chunk_size):
bytes_read += len(chunk)
progress = round(bytes_read / total_size * 100, 2)
data_msg = {
@@ -1460,25 +1460,23 @@ async def upload_model(
response = requests.post(url, data=f)
if response.ok:
log.info(f"Uploaded to /api/blobs") # DEBUG
log.info(f"Uploaded to /api/blobs") # DEBUG
# Remove local file
os.remove(file_path)
# Create model in ollama
model_name, ext = os.path.splitext(file.filename)
log.info(f"Created Model: {model_name}") # DEBUG
log.info(f"Created Model: {model_name}") # DEBUG
create_payload = {
"model": model_name,
# Reference the file by its original name => the uploaded blob's digest
"files": {
file.filename: f"sha256:{file_hash}"
},
"files": {file.filename: f"sha256:{file_hash}"},
}
log.info(f"Model Payload: {create_payload}") # DEBUG
log.info(f"Model Payload: {create_payload}") # DEBUG
# Call ollama /api/create
#https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
create_resp = requests.post(
url=f"{ollama_url}/api/create",
headers={"Content-Type": "application/json"},
@@ -1486,7 +1484,7 @@ async def upload_model(
)
if create_resp.ok:
log.info(f"API SUCCESS!") # DEBUG
log.info(f"API SUCCESS!") # DEBUG
done_msg = {
"done": True,
"blob": f"sha256:{file_hash}",
@@ -1506,4 +1504,4 @@ async def upload_model(
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
return StreamingResponse(file_process_stream(), media_type="text/event-stream")

View File

@@ -45,6 +45,7 @@ from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
@@ -379,6 +380,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -429,6 +431,7 @@ class WebSearchConfig(BaseModel):
brave_search_api_key: Optional[str] = None
kagi_search_api_key: Optional[str] = None
mojeek_search_api_key: Optional[str] = None
bocha_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
@@ -525,6 +528,9 @@ async def update_rag_config(
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
form_data.web.search.mojeek_search_api_key
)
request.app.state.config.BOCHA_SEARCH_API_KEY = (
form_data.web.search.bocha_search_api_key
)
request.app.state.config.SERPSTACK_API_KEY = (
form_data.web.search.serpstack_api_key
)
@@ -591,6 +597,7 @@ async def update_rag_config(
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -1113,6 +1120,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- BRAVE_SEARCH_API_KEY
- KAGI_SEARCH_API_KEY
- MOJEEK_SEARCH_API_KEY
- BOCHA_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
@@ -1180,6 +1188,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
elif engine == "bocha":
if request.app.state.config.BOCHA_SEARCH_API_KEY:
return search_bocha(
request.app.state.config.BOCHA_SEARCH_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if request.app.state.config.SERPSTACK_API_KEY:
return search_serpstack(

View File

@@ -153,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
if user:
return user.settings
else:

View File

@@ -94,7 +94,7 @@ class S3StorageProvider(StorageProvider):
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
@@ -108,7 +108,7 @@ class S3StorageProvider(StorageProvider):
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
try:
@@ -137,7 +137,8 @@ class S3StorageProvider(StorageProvider):
if "Contents" in response:
for content in response["Contents"]:
# Skip objects that were not uploaded from open-webui in the first place
if not content["Key"].startswith(self.key_prefix): continue
if not content["Key"].startswith(self.key_prefix):
continue
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
@@ -150,11 +151,12 @@ class S3StorageProvider(StorageProvider):
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
def _extract_s3_key(self, full_file_path: str) -> str:
return '/'.join(full_file_path.split("//")[1].split("/")[1:])
return "/".join(full_file_path.split("//")[1].split("/")[1:])
def _get_local_file_path(self, s3_key: str) -> str:
return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
class GCSStorageProvider(StorageProvider):
def __init__(self):
self.bucket_name = GCS_BUCKET_NAME

View File

@@ -0,0 +1,148 @@
import asyncio
import json
import uuid
import websockets
import requests
from urllib.parse import urljoin
async def execute_code_jupyter(
jupyter_url, code, token=None, password=None, timeout=10
):
"""
Executes Python code in a Jupyter kernel.
Supports authentication with a token or password.
:param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
:param token: Jupyter authentication token (optional)
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 10s)
:return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
"""
session = requests.Session() # Maintain cookies
headers = {} # Headers for requests
# Authenticate using password
if password and not token:
try:
login_url = urljoin(jupyter_url, "/login")
response = session.get(login_url)
response.raise_for_status()
xsrf_token = session.cookies.get("_xsrf")
if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token")
login_data = {"_xsrf": xsrf_token, "password": password}
login_response = session.post(
login_url, data=login_data, cookies=session.cookies
)
login_response.raise_for_status()
headers["X-XSRFToken"] = xsrf_token
except Exception as e:
return {
"stdout": "",
"stderr": f"Authentication Error: {str(e)}",
"result": "",
}
# Construct API URLs with authentication token if provided
params = f"?token={token}" if token else ""
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
try:
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
response.raise_for_status()
kernel_id = response.json()["id"]
websocket_url = urljoin(
jupyter_url.replace("http", "ws"),
f"/api/kernels/{kernel_id}/channels{params}",
)
ws_headers = {}
if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
cookies = {name: value for name, value in session.cookies.items()}
ws_headers["Cookie"] = "; ".join(
[f"{name}={value}" for name, value in cookies.items()]
)
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
msg_id = str(uuid.uuid4())
execute_request = {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
await ws.send(json.dumps(execute_request))
stdout, stderr, result = "", "", []
while True:
try:
message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
msg_type = message_data.get("msg_type")
if msg_type == "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
elif msg_type in ("execute_result", "display_data"):
data = message_data["content"]["data"]
if "image/png" in data:
result.append(
f"data:image/png;base64,{data['image/png']}"
)
elif "text/plain" in data:
result.append(data["text/plain"])
elif msg_type == "error":
stderr += "\n".join(message_data["content"]["traceback"])
elif (
msg_type == "status"
and message_data["content"]["execution_state"] == "idle"
):
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
except Exception as e:
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
finally:
if kernel_id:
requests.delete(
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
)
return {
"stdout": stdout.strip(),
"stderr": stderr.strip(),
"result": "\n".join(result).strip() if result else "",
}

View File

@@ -72,7 +72,7 @@ from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.tasks import create_task
@@ -684,7 +684,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
if "code_interpreter" in features and features["code_interpreter"]:
form_data["messages"] = add_or_update_user_message(
DEFAULT_CODE_INTERPRETER_PROMPT, form_data["messages"]
(
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
else DEFAULT_CODE_INTERPRETER_PROMPT
),
form_data["messages"],
)
try:
@@ -1639,21 +1644,60 @@ async def process_chat_response(
content_blocks[-1]["type"] == "code_interpreter"
and retries < MAX_RETRIES
):
await event_emitter(
{
"type": "chat:completion",
"data": {
"content": serialize_content_blocks(content_blocks),
},
}
)
retries += 1
log.debug(f"Attempt count: {retries}")
output = ""
try:
if content_blocks[-1]["attributes"].get("type") == "code":
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": content_blocks[-1]["content"],
},
code = content_blocks[-1]["content"]
if (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "pyodide"
):
output = await event_caller(
{
"type": "execute:python",
"data": {
"id": str(uuid4()),
"code": code,
},
}
)
elif (
request.app.state.config.CODE_INTERPRETER_ENGINE
== "jupyter"
):
output = await execute_code_jupyter(
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
code,
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "token"
else None
),
(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
== "password"
else None
),
)
else:
output = {
"stdout": "Code interpreter engine not configured."
}
)
if isinstance(output, dict):
stdout = output.get("stdout", "")
@@ -1687,6 +1731,38 @@ async def process_chat_response(
)
output["stdout"] = "\n".join(stdoutLines)
result = output.get("result", "")
if result:
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
)
resultLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
)
output["result"] = "\n".join(resultLines)
except Exception as e:
output = str(e)

View File

@@ -245,7 +245,7 @@ def get_gravatar_url(email):
def calculate_sha256(file_path, chunk_size):
#Compute SHA-256 hash of a file efficiently in chunks
# Compute SHA-256 hash of a file efficiently in chunks
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
while chunk := f.read(chunk_size):

View File

@@ -142,13 +142,17 @@ class OAuthManager:
log.debug(f"Oauth Groups claim: {oauth_claim}")
log.debug(f"User oauth groups: {user_oauth_groups}")
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
log.debug(f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}")
log.debug(
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
)
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(f"Removing user from group {group_model.name} as it is no longer in their oauth groups")
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
)
user_ids = group_model.user_ids
user_ids = [i for i in user_ids if i != user.id]
@@ -174,7 +178,9 @@ class OAuthManager:
gm.name == group_model.name for gm in user_current_groups
):
# Add user to group
log.debug(f"Adding user to group {group_model.name} as it was found in their oauth groups")
log.debug(
f"Adding user to group {group_model.name} as it was found in their oauth groups"
)
user_ids = group_model.user_ids
user_ids.append(user.id)
@@ -289,7 +295,9 @@ class OAuthManager:
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
guessed_mime_type = mimetypes.guess_type(
picture_url
)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
@@ -307,7 +315,8 @@ class OAuthManager:
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
name = user_data.get(username_claim)
if not isinstance(name, str):
if not name:
log.warning("Username claim is missing, using email as name")
name = email
role = self.get_user_role(None, user_data)

View File

@@ -14,6 +14,12 @@ def apply_model_system_prompt_to_body(
if not system:
return form_data
# Metadata (WebUI Usage)
if metadata:
variables = metadata.get("variables", {})
if variables:
system = prompt_variables_template(system, variables)
# Legacy (API Usage)
if user:
template_params = {
@@ -25,12 +31,6 @@ def apply_model_system_prompt_to_body(
system = prompt_template(system, **template_params)
# Metadata (WebUI Usage)
if metadata:
variables = metadata.get("variables", {})
if variables:
system = prompt_variables_template(system, variables)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from html import escape
from markdown import markdown
@@ -41,13 +42,13 @@ class PDFGenerator:
def _build_html_message(self, message: Dict[str, Any]) -> str:
"""Build HTML for a single message."""
role = message.get("role", "user")
content = message.get("content", "")
role = escape(message.get("role", "user"))
content = escape(message.get("content", ""))
timestamp = message.get("timestamp")
model = message.get("model") if role == "assistant" else ""
model = escape(message.get("model") if role == "assistant" else "")
date_str = self.format_timestamp(timestamp) if timestamp else ""
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
# extends pymdownx extension to convert markdown to html.
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
@@ -76,6 +77,7 @@ class PDFGenerator:
def _generate_html_body(self) -> str:
"""Generate the full HTML body for the PDF."""
escaped_title = escape(self.form_data.title)
return f"""
<html>
<head>
@@ -84,7 +86,7 @@ class PDFGenerator:
<body>
<div>
<div>
<h2>{self.form_data.title}</h2>
<h2>{escaped_title}</h2>
{self.messages_html}
</div>
</div>

View File

@@ -73,7 +73,9 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
"type": "function",
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": f"{tool_call.get('function', {}).get('arguments', {})}",
"arguments": json.dumps(
tool_call.get("function", {}).get("arguments", {})
),
},
}
openai_tool_calls.append(openai_tool_call)