Mirror of https://github.com/open-webui/open-webui (synced 2025-06-26 18:26:48 +00:00)

Commit: Merge remote-tracking branch 'upstream/dev' into playwright
@@ -683,6 +683,17 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)


####################################
# DIRECT CONNECTIONS
####################################

ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
    "ENABLE_DIRECT_CONNECTIONS",
    "direct.enable",
    os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
)

####################################
# OLLAMA_BASE_URL
####################################
@@ -1326,6 +1337,54 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""


####################################
# Code Interpreter
####################################

ENABLE_CODE_INTERPRETER = PersistentConfig(
    "ENABLE_CODE_INTERPRETER",
    "code_interpreter.enable",
    os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
)

CODE_INTERPRETER_ENGINE = PersistentConfig(
    "CODE_INTERPRETER_ENGINE",
    "code_interpreter.engine",
    os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
)

CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
    "CODE_INTERPRETER_PROMPT_TEMPLATE",
    "code_interpreter.prompt_template",
    os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
)

CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_URL",
    "code_interpreter.jupyter.url",
    os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""),
)

CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH",
    "code_interpreter.jupyter.auth",
    os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""),
)

CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
    "code_interpreter.jupyter.auth_token",
    os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""),
)


CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
    "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
    "code_interpreter.jupyter.auth_password",
    os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""),
)


DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
@@ -1336,9 +1395,8 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.
- **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
- **If a link to an image, audio, or any file is provided in markdown format, ALWAYS regurgitate explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**

Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
@@ -1691,6 +1749,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
    os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)

BOCHA_SEARCH_API_KEY = PersistentConfig(
    "BOCHA_SEARCH_API_KEY",
    "rag.web.search.bocha_search_api_key",
    os.getenv("BOCHA_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
@@ -97,6 +97,16 @@ from open_webui.config import (
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    OPENAI_API_CONFIGS,
    # Direct Connections
    ENABLE_DIRECT_CONNECTIONS,
    # Code Interpreter
    ENABLE_CODE_INTERPRETER,
    CODE_INTERPRETER_ENGINE,
    CODE_INTERPRETER_PROMPT_TEMPLATE,
    CODE_INTERPRETER_JUPYTER_URL,
    CODE_INTERPRETER_JUPYTER_AUTH,
    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
    # Image
    AUTOMATIC1111_API_AUTH,
    AUTOMATIC1111_BASE_URL,
@@ -183,6 +193,7 @@ from open_webui.config import (
    EXA_API_KEY,
    KAGI_SEARCH_API_KEY,
    MOJEEK_SEARCH_API_KEY,
    BOCHA_SEARCH_API_KEY,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
    GOOGLE_DRIVE_CLIENT_ID,
@@ -325,7 +336,11 @@ class SPAStaticFiles(StaticFiles):
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            if ex.status_code == 404:
                return await super().get_response("index.html", scope)
                if path.endswith(".js"):
                    # Return 404 for javascript files
                    raise ex
                else:
                    return await super().get_response("index.html", scope)
            else:
                raise ex
@@ -392,6 +407,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS

app.state.OPENAI_MODELS = {}

########################################
#
# DIRECT CONNECTIONS
#
########################################

app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS

########################################
#
# WEBUI
@@ -517,6 +540,7 @@ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@@ -574,6 +598,24 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)

########################################
#
# CODE INTERPRETER
#
########################################

app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE

app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)

########################################
#
@@ -759,6 +801,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])

app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@@ -1017,15 +1060,17 @@ async def get_app_config(request: Request):
        "enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
        **(
            {
                "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
                "enable_channels": app.state.config.ENABLE_CHANNELS,
                "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
                "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
                "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
                "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                "enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
                "enable_admin_export": ENABLE_ADMIN_EXPORT,
                "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
                "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
            }
            if user is not None
            else {}
@@ -271,6 +271,24 @@ class UsersTable:
        except Exception:
            return None

    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user_settings = db.query(User).filter_by(id=id).first().settings

                if user_settings is None:
                    user_settings = {}

                user_settings.update(updated)

                db.query(User).filter_by(id=id).update({"settings": user_settings})
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def delete_user_by_id(self, id: str) -> bool:
        try:
            # Remove User from Groups
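A minimal usage sketch (not part of the commit) of the new helper: it merges the supplied dict into any existing settings rather than replacing them. The user id below is hypothetical.

# Sketch: merge a partial settings update for an assumed user id;
# settings keys not named here are preserved by the update.
from open_webui.models.users import Users

user = Users.update_user_settings_by_id(
    "user-123",  # hypothetical user id
    {"ui": {"theme": "dark"}},  # partial settings to merge
)
if user:
    print(user.settings)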
@@ -120,18 +120,12 @@ class OpenSearchClient:
            return None

        query_body = {
            "query": {
                "bool": {
                    "filter": []
                }
            },
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }

        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({
                "term": {field: value}
            })
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})

        size = limit if limit else 10

@@ -139,7 +133,7 @@ class OpenSearchClient:
        result = self.client.search(
            index=f"{self.index_prefix}_{collection_name}",
            body=query_body,
            size=size
            size=size,
        )

        return self._result_to_get_result(result)
backend/open_webui/retrieval/web/bocha.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import logging
from typing import Optional

import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

def _parse_response(response):
    result = {}
    if "data" in response:
        data = response["data"]
        if "webPages" in data:
            webPages = data["webPages"]
            if "value" in webPages:
                result["webpage"] = [
                    {
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "url": item.get("url", ""),
                        "snippet": item.get("snippet", ""),
                        "summary": item.get("summary", ""),
                        "siteName": item.get("siteName", ""),
                        "siteIcon": item.get("siteIcon", ""),
                        "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
                    }
                    for item in webPages["value"]
                ]
    return result


def search_bocha(
    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """Search using Bocha's Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Bocha Search API key
        query (str): The query to search for
    """
    url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = json.dumps({
        "query": query,
        "summary": True,
        "freshness": "noLimit",
        "count": count
    })

    response = requests.post(url, headers=headers, data=payload, timeout=5)
    response.raise_for_status()
    results = _parse_response(response.json())
    print(results)
    if filter_list:
        results = get_filtered_results(results, filter_list)

    return [
        SearchResult(
            link=result["url"],
            title=result.get("name"),
            snippet=result.get("summary")
        )
        for result in results.get("webpage", [])[:count]
    ]
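As a usage sketch (not part of the commit), search_bocha can be called directly; the API key and query below are placeholders.

# Hypothetical example of calling the new Bocha search helper.
from open_webui.retrieval.web.bocha import search_bocha

results = search_bocha(
    api_key="YOUR_BOCHA_API_KEY",  # placeholder
    query="open source LLM frontends",
    count=5,
)
for r in results:
    print(r.link, r.title)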
@@ -8,7 +8,6 @@ from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_google_pse(
    api_key: str,
    search_engine_id: str,
@@ -17,34 +16,51 @@ def search_google_pse(
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
    Handles pagination for counts greater than 10.

    Args:
        api_key (str): A Programmable Search Engine API key
        search_engine_id (str): A Programmable Search Engine ID
        query (str): The query to search for
        count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
        filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.

    Returns:
        list[SearchResult]: A list of SearchResult objects.
    """
    url = "https://www.googleapis.com/customsearch/v1"

    headers = {"Content-Type": "application/json"}
    params = {
        "cx": search_engine_id,
        "q": query,
        "key": api_key,
        "num": count,
    }
    all_results = []
    start_index = 1  # Google PSE start parameter is 1-based

    response = requests.request("GET", url, headers=headers, params=params)
    response.raise_for_status()
    while count > 0:
        num_results_this_page = min(count, 10)  # Google PSE max results per page is 10
        params = {
            "cx": search_engine_id,
            "q": query,
            "key": api_key,
            "num": num_results_this_page,
            "start": start_index,
        }
        response = requests.request("GET", url, headers=headers, params=params)
        response.raise_for_status()
        json_response = response.json()
        results = json_response.get("items", [])
        if results:  # check if results are returned. If not, no more pages to fetch.
            all_results.extend(results)
            count -= len(results)  # Decrement count by the number of results fetched in this page.
            start_index += 10  # Increment start index for the next page
        else:
            break  # No more results from Google PSE, break the loop

    json_response = response.json()
    results = json_response.get("items", [])
    if filter_list:
        results = get_filtered_results(results, filter_list)
        all_results = get_filtered_results(all_results, filter_list)

    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("snippet"),
        )
        for result in results
        for result in all_results
    ]
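A brief sketch of how the new pagination behaves (all values below are assumed): a count above 10 is fetched in pages of at most 10 results, advancing the 1-based start parameter by 10 each time until enough items are collected or Google PSE returns no more pages.

# Hypothetical call requesting 25 results; internally this now issues
# three requests with start=1, 11, 21 and num=10, 10, 5.
results = search_google_pse(
    api_key="YOUR_PSE_API_KEY",         # placeholder
    search_engine_id="YOUR_ENGINE_ID",  # placeholder
    query="jupyter kernel websocket",
    count=25,
)
print(len(results))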
@@ -25,13 +25,10 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": api_key,
        "X-Retain-Images": "none"
        "X-Retain-Images": "none",
    }

    payload = {
        "q": query,
        "count": count if count <= 10 else 10
    }
    payload = {"q": query, "count": count if count <= 10 else 10}

    url = str(URL(jina_search_endpoint))
    response = requests.post(url, headers=headers, json=payload)
@@ -560,10 +560,14 @@ def transcribe(request: Request, file_path):

        # Extract transcript from Deepgram response
        try:
            transcript = response_data["results"]["channels"][0]["alternatives"][0].get("transcript", "")
            transcript = response_data["results"]["channels"][0]["alternatives"][
                0
            ].get("transcript", "")
        except (KeyError, IndexError) as e:
            log.error(f"Malformed response from Deepgram: {str(e)}")
            raise Exception("Failed to parse Deepgram response - unexpected response format")
            raise Exception(
                "Failed to parse Deepgram response - unexpected response format"
            )
        data = {"text": transcript.strip()}

        # Save transcript
@@ -36,6 +36,98 @@ async def export_config(user=Depends(get_admin_user)):
    return get_config()


############################
# Direct Connections Config
############################


class DirectConnectionsConfigForm(BaseModel):
    ENABLE_DIRECT_CONNECTIONS: bool


@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
    return {
        "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
    }


@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
async def set_direct_connections_config(
    request: Request,
    form_data: DirectConnectionsConfigForm,
    user=Depends(get_admin_user),
):
    request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
        form_data.ENABLE_DIRECT_CONNECTIONS
    )
    return {
        "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
    }


############################
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
    ENABLE_CODE_INTERPRETER: bool
    CODE_INTERPRETER_ENGINE: str
    CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
    CODE_INTERPRETER_JUPYTER_URL: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]


@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
    return {
        "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
        "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
        "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
        "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
        "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
    }


@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
    request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
    request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
        form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
    )

    request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
        form_data.CODE_INTERPRETER_JUPYTER_URL
    )

    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
        form_data.CODE_INTERPRETER_JUPYTER_AUTH
    )

    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
        form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
    )
    request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
        form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
    )

    return {
        "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
        "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
        "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
        "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
        "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
    }


############################
# SetDefaultModels
############################
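A hedged sketch of driving the new admin endpoint. The /api/v1/configs mount prefix is an assumption about how this router is included, and the token, host, and Jupyter values are placeholders.

# Assumed mount point and placeholder admin token; adjust to your deployment.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/configs/code_interpreter",
    headers={"Authorization": "Bearer ADMIN_TOKEN"},
    json={
        "ENABLE_CODE_INTERPRETER": True,
        "CODE_INTERPRETER_ENGINE": "jupyter",
        "CODE_INTERPRETER_PROMPT_TEMPLATE": "",
        "CODE_INTERPRETER_JUPYTER_URL": "http://localhost:8888",
        "CODE_INTERPRETER_JUPYTER_AUTH": "token",
        "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": "JUPYTER_TOKEN",
        "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": "",
    },
)
print(resp.json())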
@@ -3,30 +3,22 @@ import os
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote

from open_webui.storage.provider import Storage

from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
    FileForm,
    FileModel,
    FileModelResponse,
    Files,
)
from open_webui.routers.retrieval import process_file, ProcessFileForm

from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES


from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse


from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -41,7 +33,10 @@ router = APIRouter()

@router.post("/", response_model=FileModelResponse)
def upload_file(
    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
    request: Request,
    file: UploadFile = File(...),
    user=Depends(get_verified_user),
    file_metadata: dict = {},
):
    log.info(f"file.content_type: {file.content_type}")
    try:
@@ -65,6 +60,7 @@ def upload_file(
                "name": name,
                "content_type": file.content_type,
                "size": len(contents),
                "data": file_metadata,
            },
        }
    ),
@@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
        Storage.delete_all_files()
    except Exception as e:
        log.exception(e)
        log.error(f"Error deleting files")
        log.error("Error deleting files")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
        )
    except Exception as e:
        log.exception(e)
        log.error(f"Error getting file content")
        log.error("Error getting file content")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
        )
    except Exception as e:
        log.exception(e)
        log.error(f"Error getting file content")
        log.error("Error getting file content")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
        Storage.delete_file(file.path)
    except Exception as e:
        log.exception(e)
        log.error(f"Error deleting files")
        log.error("Error deleting files")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@@ -1,32 +1,26 @@
import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional

import requests


from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS

from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
    ComfyUIGenerateImageForm,
    ComfyUIWorkflow,
    comfyui_generate_image,
)

from pydantic import BaseModel

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
    request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):

    set_image_model(request, form_data.MODEL)

    pattern = r"^\d+x\d+$"
@@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
    negative_prompt: Optional[str] = None


def save_b64_image(b64_str):
def load_b64_image_data(b64_str):
    try:
        image_id = str(uuid.uuid4())

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            mime_type = header.split(";")[0]

            img_data = base64.b64decode(encoded)
            image_format = mimetypes.guess_extension(mime_type)

            image_filename = f"{image_id}{image_format}"
            file_path = IMAGE_CACHE_DIR / f"{image_filename}"
            with open(file_path, "wb") as f:
                f.write(img_data)
            return image_filename
        else:
            image_filename = f"{image_id}.png"
            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)

            mime_type = "image/png"
            img_data = base64.b64decode(b64_str)

            # Write the image data to a file
            with open(file_path, "wb") as f:
                f.write(img_data)
            return image_filename

        return img_data, mime_type
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        log.exception(f"Error loading image data: {e}")
        return None


def save_url_image(url, headers=None):
    image_id = str(uuid.uuid4())
def load_url_image_data(url, headers=None):
    try:
        if headers:
            r = requests.get(url, headers=headers)
@@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
        r.raise_for_status()
        if r.headers["content-type"].split("/")[0] == "image":
            mime_type = r.headers["content-type"]
            image_format = mimetypes.guess_extension(mime_type)

            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"

            file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
            with open(file_path, "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename
            return r.content, mime_type
        else:
            log.error("Url does not point to an image.")
            return None
@@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
        return None


def upload_image(request, image_metadata, image_data, content_type, user):
    image_format = mimetypes.guess_extension(content_type)
    file = UploadFile(
        file=io.BytesIO(image_data),
        filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
        headers={
            "content-type": content_type,
        },
    )
    file_item = upload_file(request, file, user, file_metadata=image_metadata)
    url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
    return url


@router.post("/generations")
async def image_generations(
    request: Request,
@@ -500,13 +478,9 @@ async def image_generations(
            images = []

            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")

                with open(file_body_path, "w") as f:
                    json.dump(data, f)

                image_data, content_type = load_b64_image_data(image["b64_json"])
                url = upload_image(request, data, image_data, content_type, user)
                images.append({"url": url})
            return images

        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@@ -552,14 +526,15 @@ async def image_generations(
                    "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
                }

                image_filename = save_url_image(image["url"], headers)
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")

                with open(file_body_path, "w") as f:
                    json.dump(form_data.model_dump(exclude_none=True), f)

            log.debug(f"images: {images}")
                image_data, content_type = load_url_image_data(image["url"], headers)
                url = upload_image(
                    request,
                    form_data.model_dump(exclude_none=True),
                    image_data,
                    content_type,
                    user,
                )
                images.append({"url": url})
            return images
        elif (
            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@@ -604,13 +579,15 @@ async def image_generations(
            images = []

            for image in res["images"]:
                image_filename = save_b64_image(image)
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")

                with open(file_body_path, "w") as f:
                    json.dump({**data, "info": res["info"]}, f)

                image_data, content_type = load_b64_image_data(image)
                url = upload_image(
                    request,
                    {**data, "info": res["info"]},
                    image_data,
                    content_type,
                    user,
                )
                images.append({"url": url})
            return images
    except Exception as e:
        error = e
@@ -1424,11 +1424,11 @@ async def upload_model(
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # --- P1: save file locally ---
    chunk_size = 1024 * 1024 * 2 # 2 MB chunks
    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
    with open(file_path, "wb") as out_f:
        while True:
            chunk = file.file.read(chunk_size)
            #log.info(f"Chunk: {str(chunk)}") # DEBUG
            # log.info(f"Chunk: {str(chunk)}") # DEBUG
            if not chunk:
                break
            out_f.write(chunk)
@@ -1436,15 +1436,15 @@ async def upload_model(
    async def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        log.info(f"Total Model Size: {str(total_size)}") # DEBUG
        log.info(f"Total Model Size: {str(total_size)}")  # DEBUG

        # --- P2: SSE progress + calculate sha256 hash ---
        file_hash = calculate_sha256(file_path, chunk_size)
        log.info(f"Model Hash: {str(file_hash)}") # DEBUG
        log.info(f"Model Hash: {str(file_hash)}")  # DEBUG
        try:
            with open(file_path, "rb") as f:
                bytes_read = 0
                while chunk := f.read(chunk_size):
                while chunk := f.read(chunk_size):
                    bytes_read += len(chunk)
                    progress = round(bytes_read / total_size * 100, 2)
                    data_msg = {
@@ -1460,25 +1460,23 @@ async def upload_model(
                response = requests.post(url, data=f)

            if response.ok:
                log.info(f"Uploaded to /api/blobs") # DEBUG
                log.info(f"Uploaded to /api/blobs")  # DEBUG
                # Remove local file
                os.remove(file_path)

                # Create model in ollama
                model_name, ext = os.path.splitext(file.filename)
                log.info(f"Created Model: {model_name}") # DEBUG
                log.info(f"Created Model: {model_name}")  # DEBUG

                create_payload = {
                    "model": model_name,
                    # Reference the file by its original name => the uploaded blob's digest
                    "files": {
                        file.filename: f"sha256:{file_hash}"
                    },
                    "files": {file.filename: f"sha256:{file_hash}"},
                }
                log.info(f"Model Payload: {create_payload}") # DEBUG
                log.info(f"Model Payload: {create_payload}")  # DEBUG

                # Call ollama /api/create
                #https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
                # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
                create_resp = requests.post(
                    url=f"{ollama_url}/api/create",
                    headers={"Content-Type": "application/json"},
@@ -1486,7 +1484,7 @@ async def upload_model(
                )

                if create_resp.ok:
                    log.info(f"API SUCCESS!") # DEBUG
                    log.info(f"API SUCCESS!")  # DEBUG
                    done_msg = {
                        "done": True,
                        "blob": f"sha256:{file_hash}",
@@ -1506,4 +1504,4 @@ async def upload_model(
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@@ -45,6 +45,7 @@ from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
@@ -379,6 +380,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
            "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
            "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
            "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
            "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
            "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
            "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
            "serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -429,6 +431,7 @@ class WebSearchConfig(BaseModel):
    brave_search_api_key: Optional[str] = None
    kagi_search_api_key: Optional[str] = None
    mojeek_search_api_key: Optional[str] = None
    bocha_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
@@ -525,6 +528,9 @@ async def update_rag_config(
        request.app.state.config.MOJEEK_SEARCH_API_KEY = (
            form_data.web.search.mojeek_search_api_key
        )
        request.app.state.config.BOCHA_SEARCH_API_KEY = (
            form_data.web.search.bocha_search_api_key
        )
        request.app.state.config.SERPSTACK_API_KEY = (
            form_data.web.search.serpstack_api_key
        )
@@ -591,6 +597,7 @@ async def update_rag_config(
            "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
            "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
            "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
            "bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
            "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
            "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
            "serper_api_key": request.app.state.config.SERPER_API_KEY,
@@ -1113,6 +1120,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    - BRAVE_SEARCH_API_KEY
    - KAGI_SEARCH_API_KEY
    - MOJEEK_SEARCH_API_KEY
    - BOCHA_SEARCH_API_KEY
    - SERPSTACK_API_KEY
    - SERPER_API_KEY
    - SERPLY_API_KEY
@@ -1180,6 +1188,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
            )
        else:
            raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
    elif engine == "bocha":
        if request.app.state.config.BOCHA_SEARCH_API_KEY:
            return search_bocha(
                request.app.state.config.BOCHA_SEARCH_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if request.app.state.config.SERPSTACK_API_KEY:
            return search_serpstack(
@@ -153,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
async def update_user_settings_by_session_user(
    form_data: UserSettings, user=Depends(get_verified_user)
):
    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
    user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
    if user:
        return user.settings
    else:
@@ -94,7 +94,7 @@ class S3StorageProvider(StorageProvider):
            aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        )
        self.bucket_name = S3_BUCKET_NAME
        self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
        self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
        """Handles uploading of the file to S3 storage."""
@@ -108,7 +108,7 @@ class S3StorageProvider(StorageProvider):
            )
        except ClientError as e:
            raise RuntimeError(f"Error uploading file to S3: {e}")


    def get_file(self, file_path: str) -> str:
        """Handles downloading of the file from S3 storage."""
        try:
@@ -137,7 +137,8 @@ class S3StorageProvider(StorageProvider):
            if "Contents" in response:
                for content in response["Contents"]:
                    # Skip objects that were not uploaded from open-webui in the first place
                    if not content["Key"].startswith(self.key_prefix): continue
                    if not content["Key"].startswith(self.key_prefix):
                        continue

                    self.s3_client.delete_object(
                        Bucket=self.bucket_name, Key=content["Key"]
@@ -150,11 +151,12 @@ class S3StorageProvider(StorageProvider):

    # The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
    def _extract_s3_key(self, full_file_path: str) -> str:
        return '/'.join(full_file_path.split("//")[1].split("/")[1:])

        return "/".join(full_file_path.split("//")[1].split("/")[1:])

    def _get_local_file_path(self, s3_key: str) -> str:
        return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"


class GCSStorageProvider(StorageProvider):
    def __init__(self):
        self.bucket_name = GCS_BUCKET_NAME
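For clarity, _extract_s3_key simply drops the scheme and bucket segments of a full S3 path and keeps the object key. A rough sketch of the expected behaviour, using an illustrative path:

# Sketch of the key-extraction logic shown above (values are illustrative).
full_file_path = "s3://my-bucket/uploads/abc123_report.pdf"
s3_key = "/".join(full_file_path.split("//")[1].split("/")[1:])
assert s3_key == "uploads/abc123_report.pdf"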
backend/open_webui/utils/code_interpreter.py (new file, 148 lines)
@@ -0,0 +1,148 @@
import asyncio
import json
import uuid
import websockets
import requests
from urllib.parse import urljoin


async def execute_code_jupyter(
    jupyter_url, code, token=None, password=None, timeout=10
):
    """
    Executes Python code in a Jupyter kernel.
    Supports authentication with a token or password.
    :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: WebSocket timeout in seconds (default: 10s)
    :return: Dictionary with stdout, stderr, and result
             - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
    """
    session = requests.Session()  # Maintain cookies
    headers = {}  # Headers for requests

    # Authenticate using password
    if password and not token:
        try:
            login_url = urljoin(jupyter_url, "/login")
            response = session.get(login_url)
            response.raise_for_status()
            xsrf_token = session.cookies.get("_xsrf")
            if not xsrf_token:
                raise ValueError("Failed to fetch _xsrf token")

            login_data = {"_xsrf": xsrf_token, "password": password}
            login_response = session.post(
                login_url, data=login_data, cookies=session.cookies
            )
            login_response.raise_for_status()
            headers["X-XSRFToken"] = xsrf_token
        except Exception as e:
            return {
                "stdout": "",
                "stderr": f"Authentication Error: {str(e)}",
                "result": "",
            }

    # Construct API URLs with authentication token if provided
    params = f"?token={token}" if token else ""
    kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")

    try:
        response = session.post(kernel_url, headers=headers, cookies=session.cookies)
        response.raise_for_status()
        kernel_id = response.json()["id"]

        websocket_url = urljoin(
            jupyter_url.replace("http", "ws"),
            f"/api/kernels/{kernel_id}/channels{params}",
        )

        ws_headers = {}
        if password and not token:
            ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
            cookies = {name: value for name, value in session.cookies.items()}
            ws_headers["Cookie"] = "; ".join(
                [f"{name}={value}" for name, value in cookies.items()]
            )

        async with websockets.connect(
            websocket_url, additional_headers=ws_headers
        ) as ws:
            msg_id = str(uuid.uuid4())
            execute_request = {
                "header": {
                    "msg_id": msg_id,
                    "msg_type": "execute_request",
                    "username": "user",
                    "session": str(uuid.uuid4()),
                    "date": "",
                    "version": "5.3",
                },
                "parent_header": {},
                "metadata": {},
                "content": {
                    "code": code,
                    "silent": False,
                    "store_history": True,
                    "user_expressions": {},
                    "allow_stdin": False,
                    "stop_on_error": True,
                },
                "channel": "shell",
            }
            await ws.send(json.dumps(execute_request))

            stdout, stderr, result = "", "", []

            while True:
                try:
                    message = await asyncio.wait_for(ws.recv(), timeout)
                    message_data = json.loads(message)
                    if message_data.get("parent_header", {}).get("msg_id") == msg_id:
                        msg_type = message_data.get("msg_type")

                        if msg_type == "stream":
                            if message_data["content"]["name"] == "stdout":
                                stdout += message_data["content"]["text"]
                            elif message_data["content"]["name"] == "stderr":
                                stderr += message_data["content"]["text"]

                        elif msg_type in ("execute_result", "display_data"):
                            data = message_data["content"]["data"]
                            if "image/png" in data:
                                result.append(
                                    f"data:image/png;base64,{data['image/png']}"
                                )
                            elif "text/plain" in data:
                                result.append(data["text/plain"])

                        elif msg_type == "error":
                            stderr += "\n".join(message_data["content"]["traceback"])

                        elif (
                            msg_type == "status"
                            and message_data["content"]["execution_state"] == "idle"
                        ):
                            break

                except asyncio.TimeoutError:
                    stderr += "\nExecution timed out."
                    break

    except Exception as e:
        return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}

    finally:
        if kernel_id:
            requests.delete(
                f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
            )

    return {
        "stdout": stdout.strip(),
        "stderr": stderr.strip(),
        "result": "\n".join(result).strip() if result else "",
    }
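A minimal usage sketch (not part of the commit) of execute_code_jupyter, assuming a local Jupyter server with token authentication; URL and token are placeholders.

# Hypothetical local Jupyter server and token; the helper is async, so run it
# with asyncio and inspect the stdout/stderr/result dictionary it returns.
import asyncio

output = asyncio.run(
    execute_code_jupyter(
        "http://localhost:8888",  # assumed Jupyter URL
        "print(21 * 2)",          # code to execute
        token="JUPYTER_TOKEN",    # placeholder token
        timeout=10,
    )
)
print(output["stdout"])  # expected: 42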
@@ -72,7 +72,7 @@ from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)

from open_webui.utils.code_interpreter import execute_code_jupyter

from open_webui.tasks import create_task
@@ -684,7 +684,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):

    if "code_interpreter" in features and features["code_interpreter"]:
        form_data["messages"] = add_or_update_user_message(
            DEFAULT_CODE_INTERPRETER_PROMPT, form_data["messages"]
            (
                request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
                if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
                else DEFAULT_CODE_INTERPRETER_PROMPT
            ),
            form_data["messages"],
        )

    try:
@@ -1639,21 +1644,60 @@ async def process_chat_response(
                        content_blocks[-1]["type"] == "code_interpreter"
                        and retries < MAX_RETRIES
                    ):
                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": {
                                    "content": serialize_content_blocks(content_blocks),
                                },
                            }
                        )

                        retries += 1
                        log.debug(f"Attempt count: {retries}")

                        output = ""
                        try:
                            if content_blocks[-1]["attributes"].get("type") == "code":
                                output = await event_caller(
                                    {
                                        "type": "execute:python",
                                        "data": {
                                            "id": str(uuid4()),
                                            "code": content_blocks[-1]["content"],
                                        },
                                code = content_blocks[-1]["content"]

                                if (
                                    request.app.state.config.CODE_INTERPRETER_ENGINE
                                    == "pyodide"
                                ):
                                    output = await event_caller(
                                        {
                                            "type": "execute:python",
                                            "data": {
                                                "id": str(uuid4()),
                                                "code": code,
                                            },
                                        }
                                    )
                                elif (
                                    request.app.state.config.CODE_INTERPRETER_ENGINE
                                    == "jupyter"
                                ):
                                    output = await execute_code_jupyter(
                                        request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
                                        code,
                                        (
                                            request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
                                            if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
                                            == "token"
                                            else None
                                        ),
                                        (
                                            request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
                                            if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH
                                            == "password"
                                            else None
                                        ),
                                    )
                                else:
                                    output = {
                                        "stdout": "Code interpreter engine not configured."
                                    }
                                )

                            if isinstance(output, dict):
                                stdout = output.get("stdout", "")
@@ -1687,6 +1731,38 @@ async def process_chat_response(
                                    )

                                output["stdout"] = "\n".join(stdoutLines)

                            result = output.get("result", "")

                            if result:
                                resultLines = result.split("\n")
                                for idx, line in enumerate(resultLines):
                                    if "data:image/png;base64" in line:
                                        id = str(uuid4())

                                        # ensure the path exists
                                        os.makedirs(
                                            os.path.join(CACHE_DIR, "images"),
                                            exist_ok=True,
                                        )

                                        image_path = os.path.join(
                                            CACHE_DIR,
                                            f"images/{id}.png",
                                        )

                                        with open(image_path, "wb") as f:
                                            f.write(
                                                base64.b64decode(
                                                    line.split(",")[1]
                                                )
                                            )

                                        resultLines[idx] = (
                                            f"![Output Image {idx}](/cache/images/{id}.png)"
                                        )

                                output["result"] = "\n".join(resultLines)
                        except Exception as e:
                            output = str(e)
@@ -245,7 +245,7 @@ def get_gravatar_url(email):


def calculate_sha256(file_path, chunk_size):
    #Compute SHA-256 hash of a file efficiently in chunks
    # Compute SHA-256 hash of a file efficiently in chunks
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
@@ -142,13 +142,17 @@ class OAuthManager:
        log.debug(f"Oauth Groups claim: {oauth_claim}")
        log.debug(f"User oauth groups: {user_oauth_groups}")
        log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
        log.debug(f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}")
        log.debug(
            f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
        )

        # Remove groups that user is no longer a part of
        for group_model in user_current_groups:
            if group_model.name not in user_oauth_groups:
                # Remove group from user
                log.debug(f"Removing user from group {group_model.name} as it is no longer in their oauth groups")
                log.debug(
                    f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
                )

                user_ids = group_model.user_ids
                user_ids = [i for i in user_ids if i != user.id]
@@ -174,7 +178,9 @@ class OAuthManager:
                gm.name == group_model.name for gm in user_current_groups
            ):
                # Add user to group
                log.debug(f"Adding user to group {group_model.name} as it was found in their oauth groups")
                log.debug(
                    f"Adding user to group {group_model.name} as it was found in their oauth groups"
                )

                user_ids = group_model.user_ids
                user_ids.append(user.id)
@@ -289,7 +295,9 @@ class OAuthManager:
                        base64_encoded_picture = base64.b64encode(
                            picture
                        ).decode("utf-8")
                        guessed_mime_type = mimetypes.guess_type(picture_url)[0]
                        guessed_mime_type = mimetypes.guess_type(
                            picture_url
                        )[0]
                        if guessed_mime_type is None:
                            # assume JPG, browsers are tolerant enough of image formats
                            guessed_mime_type = "image/jpeg"
@@ -307,7 +315,8 @@ class OAuthManager:
            username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM

            name = user_data.get(username_claim)
            if not isinstance(name, str):
            if not name:
                log.warning("Username claim is missing, using email as name")
                name = email

            role = self.get_user_role(None, user_data)
@@ -14,6 +14,12 @@ def apply_model_system_prompt_to_body(
    if not system:
        return form_data

    # Metadata (WebUI Usage)
    if metadata:
        variables = metadata.get("variables", {})
        if variables:
            system = prompt_variables_template(system, variables)

    # Legacy (API Usage)
    if user:
        template_params = {
@@ -25,12 +31,6 @@ def apply_model_system_prompt_to_body(

        system = prompt_template(system, **template_params)

    # Metadata (WebUI Usage)
    if metadata:
        variables = metadata.get("variables", {})
        if variables:
            system = prompt_variables_template(system, variables)

    form_data["messages"] = add_or_update_system_message(
        system, form_data.get("messages", [])
    )
@@ -2,6 +2,7 @@ from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from html import escape

from markdown import markdown

@@ -41,13 +42,13 @@ class PDFGenerator:

    def _build_html_message(self, message: Dict[str, Any]) -> str:
        """Build HTML for a single message."""
        role = message.get("role", "user")
        content = message.get("content", "")
        role = escape(message.get("role", "user"))
        content = escape(message.get("content", ""))
        timestamp = message.get("timestamp")

        model = message.get("model") if role == "assistant" else ""
        model = escape(message.get("model") if role == "assistant" else "")

        date_str = self.format_timestamp(timestamp) if timestamp else ""
        date_str = escape(self.format_timestamp(timestamp) if timestamp else "")

        # extends pymdownx extension to convert markdown to html.
        # - https://facelessuser.github.io/pymdown-extensions/usage_notes/
@@ -76,6 +77,7 @@ class PDFGenerator:

    def _generate_html_body(self) -> str:
        """Generate the full HTML body for the PDF."""
        escaped_title = escape(self.form_data.title)
        return f"""
        <html>
            <head>
@@ -84,7 +86,7 @@ class PDFGenerator:
            <body>
                <div>
                    <div>
                        <h2>{self.form_data.title}</h2>
                        <h2>{escaped_title}</h2>
                        {self.messages_html}
                    </div>
                </div>
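The escape() calls above are the standard library's html.escape; as a quick illustration of what they prevent in the generated PDF HTML:

# Example of the escaping now applied to titles and message fields.
from html import escape

print(escape('<script>alert("x")</script>'))
# &lt;script&gt;alert(&quot;x&quot;)&lt;/script&gt;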
@@ -73,7 +73,9 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
                "type": "function",
                "function": {
                    "name": tool_call.get("function", {}).get("name", ""),
                    "arguments": f"{tool_call.get('function', {}).get('arguments', {})}",
                    "arguments": json.dumps(
                        tool_call.get("function", {}).get("arguments", {})
                    ),
                },
            }
            openai_tool_calls.append(openai_tool_call)
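The switch from an f-string to json.dumps matters because Python's str() of a dict is not valid JSON; a small sketch of the difference with illustrative arguments:

import json

args = {"city": "Berlin", "unit": "celsius"}
print(f"{args}")         # {'city': 'Berlin', 'unit': 'celsius'}  (single quotes, not JSON)
print(json.dumps(args))  # {"city": "Berlin", "unit": "celsius"}  (valid JSON string)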