Merge branch 'open-webui:main' into main

This commit is contained in:
Igor Brai
2024-11-21 15:25:41 +02:00
committed by GitHub
264 changed files with 29675 additions and 15258 deletions

View File

@@ -32,7 +32,13 @@ from open_webui.config import (
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
@@ -47,7 +53,12 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -74,6 +85,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
@@ -231,6 +246,21 @@ async def update_audio_config(
}
def load_speech_pipeline():
from transformers import pipeline
from datasets import load_dataset
if app.state.speech_synthesiser is None:
app.state.speech_synthesiser = pipeline(
"text-to-speech", "microsoft/speecht5_tts"
)
if app.state.speech_speaker_embeddings_dataset is None:
app.state.speech_speaker_embeddings_dataset = load_dataset(
"Matthijs/cmu-arctic-xvectors", split="validation"
)
@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
@@ -248,6 +278,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
try:
body = body.decode("utf-8")
body = json.loads(body)
@@ -391,6 +427,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(
status_code=500, detail=f"Error synthesizing speech - {response.reason}"
)
elif app.state.config.TTS_ENGINE == "transformers":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
import torch
import soundfile as sf
load_speech_pipeline()
embeddings_dataset = app.state.speech_speaker_embeddings_dataset
speaker_index = 6799
try:
speaker_index = embeddings_dataset["filename"].index(
app.state.config.TTS_MODEL
)
except Exception:
pass
speaker_embedding = torch.tensor(
embeddings_dataset[speaker_index]["xvector"]
).unsqueeze(0)
speech = app.state.speech_synthesiser(
payload["input"],
forward_params={"speaker_embeddings": speaker_embedding},
)
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
return FileResponse(file_path)
def transcribe(file_path):

View File

@@ -35,7 +35,8 @@ from open_webui.config import (
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
@@ -47,7 +48,12 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -456,6 +462,12 @@ async def image_generations(
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
data = {
"model": (
app.state.config.MODEL

View File

@@ -13,18 +13,20 @@ import requests
from open_webui.apps.webui.models.models import Models
from open_webui.config import (
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
MODEL_FILTER_LIST,
OLLAMA_BASE_URLS,
OLLAMA_API_CONFIGS,
UPLOAD_DIR,
AppConfig,
)
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
@@ -41,11 +43,18 @@ from open_webui.utils.payload import (
apply_model_system_prompt_to_body,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -56,12 +65,9 @@ app.add_middleware(
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
@@ -69,60 +75,98 @@ app.state.MODELS = {}
# least connections, or least response time for better resource utilization and performance optimization.
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
@app.head("/")
@app.get("/")
async def get_status():
return {"status": True}
class ConnectionVerificationForm(BaseModel):
url: str
key: Optional[str] = None
@app.post("/verify")
async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
url = form_data.url
key = form_data.key
headers = {}
if key:
headers["Authorization"] = f"Bearer {key}"
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get(f"{url}/api/version", headers=headers) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
}
class OllamaConfigForm(BaseModel):
enable_ollama_api: Optional[bool] = None
ENABLE_OLLAMA_API: Optional[bool] = None
OLLAMA_BASE_URLS: list[str]
OLLAMA_API_CONFIGS: dict
@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
# Remove any extra configs
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
for url in list(app.state.config.OLLAMA_BASE_URLS):
if url not in config_urls:
app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
}
@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel):
urls: list[str]
@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.config.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=3)
async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url) as response:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
@@ -148,10 +192,18 @@ async def post_streaming_url(
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = await session.post(
url,
data=payload,
headers={"Content-Type": "application/json"},
headers=headers,
)
r.raise_for_status()
@@ -194,29 +246,62 @@ def merge_models_lists(model_lists):
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
digest = model["digest"]
if digest not in merged_models:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[digest] = model
merged_models[id] = model
else:
merged_models[digest]["urls"].append(idx)
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
async def get_all_models():
log.info("get_all_models()")
if app.state.config.ENABLE_OLLAMA_API:
tasks = [
fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
]
tasks = []
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
if url not in app.state.config.OLLAMA_API_CONFIGS:
tasks.append(aiohttp_get(f"{url}/api/tags"))
else:
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
tasks.append(aiohttp_get(f"{url}/api/tags", key))
else:
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*tasks)
for idx, response in enumerate(responses):
if response:
url = app.state.config.OLLAMA_BASE_URLS[idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
model_ids = api_config.get("model_ids", [])
if len(model_ids) != 0 and "models" in response:
response["models"] = list(
filter(
lambda model: model["model"] in model_ids,
response["models"],
)
)
if prefix_id:
for model in response.get("models", []):
model["model"] = f"{prefix_id}.{model['model']}"
print(responses)
models = {
"models": merge_models_lists(
map(
lambda response: response["models"] if response else None, responses
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
@@ -224,8 +309,6 @@ async def get_all_models():
else:
models = {"models": []}
app.state.MODELS = {model["model"]: model for model in models["models"]}
return models
@@ -234,29 +317,25 @@ async def get_all_models():
async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
models = []
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
return models
return models
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {}
if key:
headers["Authorization"] = f"Bearer {key}"
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
r.raise_for_status()
return r.json()
models = r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
@@ -273,6 +352,20 @@ async def get_ollama_tags(
detail=error_detail,
)
if user.role == "user":
# Filter models based on user access control
filtered_models = []
for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models["models"] = filtered_models
return models
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
@@ -281,7 +374,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx is None:
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version")
aiohttp_get(
f"{url}/api/version",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
)
for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
@@ -361,8 +457,11 @@ async def push_model(
user=Depends(get_admin_user),
):
if url_idx is None:
if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0]
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
if form_data.name in models:
url_idx = models[form_data.name]["urls"][0]
else:
raise HTTPException(
status_code=400,
@@ -411,8 +510,11 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
if form_data.source in app.state.MODELS:
url_idx = app.state.MODELS[form_data.source]["urls"][0]
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
if form_data.source in models:
url_idx = models[form_data.source]["urls"][0]
else:
raise HTTPException(
status_code=400,
@@ -421,10 +523,18 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/copy",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -459,8 +569,11 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0]
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
if form_data.name in models:
url_idx = models[form_data.name]["urls"][0]
else:
raise HTTPException(
status_code=400,
@@ -470,11 +583,18 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
headers=headers,
)
try:
r.raise_for_status()
@@ -501,20 +621,30 @@ async def delete_model(
@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
if form_data.name not in app.state.MODELS:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
if form_data.name not in models:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url_idx = random.choice(models[form_data.name]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/show",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
@@ -570,23 +700,26 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
def generate_ollama_embeddings(
async def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
if model in models:
url_idx = random.choice(models[model]["urls"])
else:
raise HTTPException(
status_code=400,
@@ -596,10 +729,17 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
@@ -630,20 +770,23 @@ def generate_ollama_embeddings(
)
def generate_ollama_batch_embeddings(
async def generate_ollama_batch_embeddings(
form_data: GenerateEmbedForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
if model in models:
url_idx = random.choice(models[model]["urls"])
else:
raise HTTPException(
status_code=400,
@@ -653,10 +796,17 @@ def generate_ollama_batch_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
@@ -692,7 +842,7 @@ class GenerateCompletionForm(BaseModel):
options: Optional[dict] = None
system: Optional[str] = None
template: Optional[str] = None
context: Optional[str] = None
context: Optional[list[int]] = None
stream: Optional[bool] = True
raw: Optional[bool] = None
keep_alive: Optional[Union[int, str]] = None
@@ -706,13 +856,16 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
if model in models:
url_idx = random.choice(models[model]["urls"])
else:
raise HTTPException(
status_code=400,
@@ -720,6 +873,10 @@ async def generate_completion(
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
log.info(f"url: {url}")
return await post_streaming_url(
@@ -743,14 +900,17 @@ class GenerateChatCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None
def get_ollama_url(url_idx: Optional[int], model: str):
async def get_ollama_url(url_idx: Optional[int], model: str):
if url_idx is None:
if model not in app.state.MODELS:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
if model not in models:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
url_idx = random.choice(app.state.MODELS[model]["urls"])
url_idx = random.choice(models[model]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
return url
@@ -768,15 +928,7 @@ async def generate_chat_completion(
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.model
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
@@ -794,13 +946,37 @@ async def generate_chat_completion(
)
payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
elif not bypass_filter:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
url = get_ollama_url(url_idx, payload["model"])
url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url(
f"{url}/api/chat",
json.dumps(payload),
@@ -817,7 +993,7 @@ class OpenAIChatMessageContent(BaseModel):
class OpenAIChatMessage(BaseModel):
role: str
content: Union[str, OpenAIChatMessageContent]
content: Union[str, list[OpenAIChatMessageContent]]
model_config = ConfigDict(extra="allow")
@@ -836,22 +1012,24 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
completion_form = OpenAIChatCompletionForm(**form_data)
try:
completion_form = OpenAIChatCompletionForm(**form_data)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=400,
detail=str(e),
)
payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
if "metadata" in payload:
del payload["metadata"]
model_id = completion_form.model
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in model_id:
model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -862,12 +1040,36 @@ async def generate_openai_chat_completion(
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
else:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
url = get_ollama_url(url_idx, payload["model"])
url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url(
f"{url}/v1/chat/completions",
json.dumps(payload),
@@ -881,31 +1083,19 @@ async def get_openai_models(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
models = []
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
return {
"data": [
{
"id": model["model"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in models["models"]
],
"object": "list",
}
model_list = await get_all_models()
models = [
{
"id": model["model"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in model_list["models"]
]
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -913,21 +1103,17 @@ async def get_openai_models(
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
models = r.json()
return {
"data": [
{
"id": model["model"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in models["models"]
],
"object": "list",
}
model_list = r.json()
models = [
{
"id": model["model"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in models["models"]
]
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
@@ -944,6 +1130,23 @@ async def get_openai_models(
detail=error_detail,
)
if user.role == "user":
# Filter models based on user access control
filtered_models = []
for model in models:
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models = filtered_models
return {
"data": models,
"object": "list",
}
class UrlForm(BaseModel):
url: str

View File

@@ -11,20 +11,20 @@ from open_webui.apps.webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OPENAI_API,
MODEL_FILTER_LIST,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
AppConfig,
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
@@ -37,11 +37,20 @@ from open_webui.utils.payload import (
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -52,69 +61,66 @@ app.add_middleware(
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
response = await call_next(request)
return response
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
}
class OpenAIConfigForm(BaseModel):
enable_openai_api: Optional[bool] = None
ENABLE_OPENAI_API: Optional[bool] = None
OPENAI_API_BASE_URLS: list[str]
OPENAI_API_KEYS: list[str]
OPENAI_API_CONFIGS: dict
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
class UrlsUpdateForm(BaseModel):
urls: list[str]
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
else:
app.state.config.OPENAI_API_KEYS += [""] * (
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
class KeysUpdateForm(BaseModel):
keys: list[str]
# Remove any extra configs
config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
if url not in config_urls:
app.state.config.OPENAI_API_CONFIGS.pop(url, None)
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
}
@app.post("/audio/speech")
@@ -140,6 +146,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
try:
r = requests.post(
@@ -181,10 +192,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def fetch_url(url, key):
async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"}
headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
@@ -239,12 +250,8 @@ def merge_models_lists(model_lists):
return merged_list
def is_openai_api_disabled():
return not app.state.config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list:
if is_openai_api_disabled():
async def get_all_models_responses() -> list:
if not app.state.config.ENABLE_OPENAI_API:
return []
# Check if API KEYS length is same than API URLS length
@@ -260,33 +267,69 @@ async def get_all_models_raw() -> list:
else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
tasks = []
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
if url not in app.state.config.OPENAI_API_CONFIGS:
tasks.append(
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
)
else:
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
model_ids = api_config.get("model_ids", [])
if enable:
if len(model_ids) == 0:
tasks.append(
aiohttp_get(
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
)
)
else:
model_list = {
"object": "list",
"data": [
{
"id": model_id,
"name": model_id,
"owned_by": "openai",
"openai": {"id": model_id},
"urlIdx": idx,
}
for model_id in model_ids
],
}
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
responses = await asyncio.gather(*tasks)
for idx, response in enumerate(responses):
if response:
url = app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
model["id"] = f"{prefix_id}.{model['id']}"
log.debug(f"get_all_models:responses() {responses}")
return responses
@overload
async def get_all_models(raw: Literal[True]) -> list: ...
@overload
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
async def get_all_models(raw=False) -> dict[str, list] | list:
async def get_all_models() -> dict[str, list]:
log.info("get_all_models()")
if is_openai_api_disabled():
return [] if raw else {"data": []}
responses = await get_all_models_raw()
if raw:
return responses
if not app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses()
def extract_data(response):
if response and "data" in response:
@@ -296,9 +339,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
return None
models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@@ -306,18 +347,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
models = {
"data": [],
}
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx]
@@ -326,56 +361,126 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
try:
r = requests.request(method="GET", url=f"{url}/models", headers=headers)
r.raise_for_status()
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = r.json()
response_data = await r.json()
if "api.openai.com" in url:
# Filter the response data
response_data["data"] = [
model
for model in response_data["data"]
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
# Check if we're calling OpenAI API based on the URL
if "api.openai.com" in url:
# Filter models according to the specified conditions
response_data["data"] = [
model
for model in response_data.get("data", [])
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
]
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
models = response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
if user.role == "user":
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models["data"] = filtered_models
return models
class ConnectionVerificationForm(BaseModel):
url: str
key: str
@app.post("/verify")
async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
url = form_data.url
key = form_data.key
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
idx = 0
payload = {**form_data}
@@ -386,6 +491,7 @@ async def generate_chat_completion(
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
# Check model info and override the payload
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -394,9 +500,52 @@ async def generate_chat_completion(
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
elif not bypass_filter:
if user.role != "admin":
raise HTTPException(
status_code=403,
detail="Model not found",
)
# Attemp to get urlIdx from the model
models = await get_all_models()
# Find the model from the list
model = next(
(model for model in models["data"] if model["id"] == payload.get("model")),
None,
)
if model:
idx = model["urlIdx"]
else:
raise HTTPException(
status_code=404,
detail="Model not found",
)
# Get the API config for the model
api_config = app.state.config.OPENAI_API_CONFIGS.get(
app.state.config.OPENAI_API_BASE_URLS[idx], {}
)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# Add user info to the payload if the model is a pipeline
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {
"name": user.name,
@@ -407,8 +556,9 @@ async def generate_chat_completion(
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
is_o1 = payload["model"].lower().startswith("o1-")
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-")
# Change max_completion_tokens to max_tokens (Backward compatible)
if "api.openai.com" not in url and not is_o1:
if "max_completion_tokens" in payload:
@@ -437,6 +587,11 @@ async def generate_chat_completion(
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
session = None
@@ -505,6 +660,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
session = None

View File

@@ -159,7 +159,7 @@ class Loader:
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (

View File

@@ -0,0 +1,98 @@
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
"youtu.be",
"m.youtube.com",
"youtube.com",
"www.youtube.com",
"www.youtube-nocookie.com",
"vid.plus",
}
def _parse_video_id(url: str) -> Optional[str]:
"""Parse a YouTube URL and return the video ID if valid, otherwise None."""
parsed_url = urlparse(url)
if parsed_url.scheme not in ALLOWED_SCHEMES:
return None
if parsed_url.netloc not in ALLOWED_NETLOCS:
return None
path = parsed_url.path
if path.endswith("/watch"):
query = parsed_url.query
parsed_query = parse_qs(query)
if "v" in parsed_query:
ids = parsed_query["v"]
video_id = ids if isinstance(ids, str) else ids[0]
else:
return None
else:
path = parsed_url.path.lstrip("/")
video_id = path.split("/")[-1]
if len(video_id) != 11: # Video IDs are 11 characters long
return None
return video_id
class YoutubeLoader:
"""Load `YouTube` video transcripts."""
def __init__(
self,
video_id: str,
language: Union[str, Sequence[str]] = "en",
):
"""Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self.language = language
if isinstance(language, str):
self.language = [language]
else:
self.language = language
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
try:
from youtube_transcript_api import (
NoTranscriptFound,
TranscriptsDisabled,
YouTubeTranscriptApi,
)
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`."
)
try:
transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)
except Exception as e:
print(e)
return []
try:
transcript = transcript_list.find_transcript(self.language)
except NoTranscriptFound:
transcript = transcript_list.find_transcript(["en"])
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript = " ".join(
map(
lambda transcript_piece: transcript_piece["text"].strip(" "),
transcript_pieces,
)
)
return [Document(page_content=transcript, metadata=self._metadata)]

View File

@@ -23,6 +23,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
# Document loaders
from open_webui.apps.retrieval.loaders.main import Loader
from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader
# Web search engines
from open_webui.apps.retrieval.web.main import SearchResult
@@ -38,6 +39,7 @@ from open_webui.apps.retrieval.web.serper import search_serper
from open_webui.apps.retrieval.web.serply import search_serply
from open_webui.apps.retrieval.web.serpstack import search_serpstack
from open_webui.apps.retrieval.web.tavily import search_tavily
from open_webui.apps.retrieval.web.bing import search_bing
from open_webui.apps.retrieval.utils import (
@@ -76,6 +78,8 @@ from open_webui.config import (
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
RAG_RELEVANCE_THRESHOLD,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -87,6 +91,7 @@ from open_webui.config import (
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_RESULT_COUNT,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
SEARXNG_QUERY_URL,
@@ -95,13 +100,20 @@ from open_webui.config import (
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
TAVILY_API_KEY,
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
TIKA_SERVER_URL,
UPLOAD_DIR,
YOUTUBE_LOADER_LANGUAGE,
DEFAULT_LOCALE,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
from open_webui.env import (
SRC_LOG_LEVELS,
DEVICE_TYPE,
DOCKER,
)
from open_webui.utils.misc import (
calculate_sha256,
calculate_sha256_string,
@@ -111,16 +123,17 @@ from open_webui.utils.misc import (
from open_webui.utils.utils import get_admin_user, get_verified_user
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.state.config = AppConfig()
@@ -152,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
@@ -174,6 +190,10 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
@@ -185,11 +205,15 @@ def update_embedding_model(
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
from sentence_transformers import SentenceTransformer
app.state.sentence_transformer_ef = SentenceTransformer(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
try:
app.state.sentence_transformer_ef = SentenceTransformer(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
except Exception as e:
log.debug(f"Error loading SentenceTransformer: {e}")
app.state.sentence_transformer_ef = None
else:
app.state.sentence_transformer_ef = None
@@ -243,8 +267,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
@@ -294,6 +326,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
}
@@ -310,8 +346,14 @@ class OpenAIConfigForm(BaseModel):
key: str
class OllamaConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@@ -332,6 +374,11 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
if form_data.ollama_config is not None:
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -340,8 +387,16 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
@@ -354,6 +409,10 @@ async def update_embedding_config(
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@@ -414,7 +473,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -430,6 +489,9 @@ async def get_rag_config(user=Depends(get_admin_user)):
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
"seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
"jina_api_key": app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -473,6 +535,9 @@ class WebSearchConfig(BaseModel):
tavily_api_key: Optional[str] = None
searchapi_api_key: Optional[str] = None
searchapi_engine: Optional[str] = None
jina_api_key: Optional[str] = None
bing_search_v7_endpoint: Optional[str] = None
bing_search_v7_subscription_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
@@ -519,6 +584,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
if form_data.web is not None:
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
form_data.web.web_loader_ssl_verification
)
@@ -540,6 +606,15 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
app.state.config.BING_SEARCH_V7_ENDPOINT = (
form_data.web.search.bing_search_v7_endpoint
)
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
form_data.web.search.bing_search_v7_subscription_key
)
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
@@ -566,7 +641,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
@@ -582,6 +657,9 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
"searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"jina_api_key": app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -643,6 +721,23 @@ async def update_query_settings(
####################################
def _get_docs_info(docs: list[Document]) -> str:
docs_info = set()
# Trying to select relevant metadata identifying the document.
for doc in docs:
metadata = getattr(doc, "metadata", {})
doc_name = metadata.get("name", "")
if not doc_name:
doc_name = metadata.get("title", "")
if not doc_name:
doc_name = metadata.get("source", "")
if doc_name:
docs_info.add(doc_name)
return ", ".join(docs_info)
def save_docs_to_vector_db(
docs,
collection_name,
@@ -651,7 +746,9 @@ def save_docs_to_vector_db(
split: bool = True,
add: bool = False,
) -> bool:
log.info(f"save_docs_to_vector_db {docs} {collection_name}")
log.info(
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
)
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
@@ -733,8 +830,16 @@ def save_docs_to_vector_db(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
@@ -959,12 +1064,10 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u
if not collection_name:
collection_name = calculate_sha256_string(form_data.url)[:63]
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
loader = YoutubeLoader(
form_data.url, language=app.state.config.YOUTUBE_LOADER_LANGUAGE
)
docs = loader.load()
content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {content}")
@@ -1150,7 +1253,20 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
else:
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
return search_jina(
app.state.config.JINA_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
elif engine == "bing":
return search_bing(
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
app.state.config.BING_SEARCH_V7_ENDPOINT,
str(DEFAULT_LOCALE),
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No search engine API key found in environment variables")
@@ -1180,8 +1296,12 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
urls = [result.link for result in web_results]
loader = get_web_loader(urls)
docs = loader.load()
loader = get_web_loader(
urls,
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
)
docs = loader.aload()
save_docs_to_vector_db(docs, collection_name, overwrite=True)

View File

@@ -3,6 +3,7 @@ import os
import uuid
from typing import Optional, Union
import asyncio
import requests
from huggingface_hub import snapshot_download
@@ -10,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@@ -76,7 +72,7 @@ def query_doc(
limit=k,
)
log.info(f"query_doc:result {result}")
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)
@@ -127,7 +123,10 @@ def query_doc_with_hybrid_search(
"metadatas": [[d.metadata for d in result]],
}
log.info(f"query_doc_with_hybrid_search:result {result}")
log.info(
"query_doc_with_hybrid_search:result "
+ f'{result["metadatas"]} {result["distances"]}'
)
return result
except Exception as e:
raise e
@@ -178,35 +177,34 @@ def merge_and_sort_query_results(
def query_collection(
collection_names: list[str],
query: str,
queries: list[str],
embedding_function,
k: int,
) -> dict:
results = []
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
for query in queries:
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: list[str],
query: str,
queries: list[str],
embedding_function,
k: int,
reranking_function,
@@ -216,15 +214,16 @@ def query_collection_with_hybrid_search(
error = False
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
for query in queries:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
@@ -281,8 +280,8 @@ def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
url,
key,
embedding_batch_size,
):
if embedding_engine == "":
@@ -292,8 +291,8 @@ def get_embedding_function(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
url=url,
key=key,
)
def generate_multiple(query, func):
@@ -310,15 +309,14 @@ def get_embedding_function(
def get_rag_context(
files,
messages,
queries,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
):
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
extracted_collections = []
relevant_contexts = []
@@ -360,7 +358,7 @@ def get_rag_context(
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
query=query,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
@@ -375,7 +373,7 @@ def get_rag_context(
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
query=query,
queries=queries,
embedding_function=embedding_function,
k=k,
)
@@ -467,7 +465,7 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@@ -489,29 +487,50 @@ def generate_openai_batch_embeddings(
return None
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/api/embed",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
print(data)
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
**{"model": model, "texts": text, "url": url, "key": key}
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
**{"model": model, "texts": [text], "url": url, "key": key}
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
embeddings = generate_openai_batch_embeddings(model, text, url, key)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
return embeddings[0] if isinstance(text, str) else embeddings

View File

@@ -8,6 +8,14 @@ elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient

View File

@@ -13,11 +13,24 @@ from open_webui.config import (
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
class ChromaClient:
def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
@@ -26,12 +39,12 @@ class ChromaClient:
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
settings=Settings(**settings_dict),
)
else:
self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
settings=Settings(**settings_dict),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)

View File

@@ -0,0 +1,178 @@
from opensearchpy import OpenSearch
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
OPENSEARCH_PASSWORD,
)
class OpenSearchClient:
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(
hosts=[OPENSEARCH_URI],
use_ssl=OPENSEARCH_SSL,
verify_certs=OPENSEARCH_CERT_VERIFY,
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
def _create_index(self, index_name: str, dimension: int):
body = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16,
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
}
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
def delete_colleciton(self, index_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
def search(
self, index_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_search_result(result)
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)
def get(self, index_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
def upsert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
def delete(self, index_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
for index in indices:
self.client.indices.delete(index=index)

View File

@@ -0,0 +1,354 @@
from typing import Optional, List, Dict, Any
from sqlalchemy import (
cast,
column,
create_engine,
Column,
Integer,
select,
text,
Text,
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL
VECTOR_LENGTH = 1536
Base = declarative_base()
class DocumentChunk(Base):
__tablename__ = "document_chunk"
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False)
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.apps.webui.internal.db import Session
self.session = Session
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# Create the tables if they do not exist
# Base.metadata.create_all requires a bind (engine or connection)
# Get the connection from the session
connection = self.session.connection()
Base.metadata.create_all(bind=connection)
# Create an index on the vector column if it doesn't exist
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
)
)
self.session.commit()
print("Initialization complete.")
except Exception as e:
self.session.rollback()
print(f"Error during initialization: {e}")
raise
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
current_length = len(vector)
if current_length < VECTOR_LENGTH:
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
raise Exception(
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
)
return vector
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
print(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during upsert: {e}")
raise
def search(
self,
collection_name: str,
vectors: List[List[float]],
limit: Optional[int] = None,
) -> Optional[SearchResult]:
try:
if not vectors:
return None
# Adjust query vectors to VECTOR_LENGTH
vectors = [self.adjust_vector_length(vector) for vector in vectors]
num_queries = len(vectors)
def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
)
# Build the lateral subquery for each query vector
subq = (
select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
# Build the main query by joining query_vectors and the lateral subquery
stmt = (
select(
query_vectors.c.qid,
subq.c.id,
subq.c.text,
subq.c.vmetadata,
subq.c.distance,
)
.select_from(query_vectors)
.join(subq, true())
.order_by(query_vectors.c.qid, subq.c.distance)
)
result_proxy = self.session.execute(stmt)
results = result_proxy.all()
ids = [[] for _ in range(num_queries)]
distances = [[] for _ in range(num_queries)]
documents = [[] for _ in range(num_queries)]
metadatas = [[] for _ in range(num_queries)]
if not results:
return SearchResult(
ids=ids,
distances=distances,
documents=documents,
metadatas=metadatas,
)
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
distances[qid].append(row.distance)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
print(f"Error during search: {e}")
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(
ids=ids,
documents=documents,
metadatas=metadatas,
)
except Exception as e:
print(f"Error during query: {e}")
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
print(f"Error during get: {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
print(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
print(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
print(f"Error during reset: {e}")
raise
def close(self) -> None:
pass
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
is not None
)
return exists
except Exception as e:
print(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
print(f"Collection '{collection_name}' deleted.")

View File

@@ -5,7 +5,7 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
NO_LIMIT = 999999999
@@ -14,7 +14,12 @@ class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
def _result_to_get_result(self, points) -> GetResult:
ids = []

View File

@@ -0,0 +1,73 @@
import logging
import os
from pprint import pprint
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""
def search_bing(
subscription_key: str,
endpoint: str,
locale: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
mkt = locale
params = {"q": query, "mkt": mkt, "answerCount": count}
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
try:
response = requests.get(endpoint, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("webPages", {}).get("value", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("name"),
snippet=result.get("snippet"),
)
for result in results
]
except Exception as ex:
log.error(f"Error: {ex}")
raise ex
def main():
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
parser.add_argument(
"query",
type=str,
default="Top 10 international news today",
help="The search query.",
)
parser.add_argument(
"--count", type=int, default=10, help="Number of search results to return."
)
parser.add_argument(
"--filter", nargs="*", help="List of filters to apply to the search results."
)
parser.add_argument(
"--locale",
type=str,
default="en-US",
help="The locale to use for the search, maps to market in api",
)
args = parser.parse_args()
results = search_bing(args.locale, args.query, args.count, args.filter)
pprint(results)

View File

@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
@@ -20,9 +20,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()

View File

@@ -0,0 +1,58 @@
{
"_type": "SearchResponse",
"queryContext": {
"originalQuery": "Top 10 international results"
},
"webPages": {
"webSearchUrl": "https://www.bing.com/search?q=Top+10+international+results",
"totalEstimatedMatches": 687,
"value": [
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.0",
"name": "2024 Mexican Grand Prix - F1 results and latest standings ... - PlanetF1",
"url": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
"datePublished": "2024-10-27T00:00:00.0000000",
"datePublishedFreshnessText": "1 day ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.planetf1.com/news/f1-results-2024-mexican-grand-prix-race-standings",
"snippet": "Nico Hulkenberg and Pierre Gasly completed the top 10. A full report of the Mexican Grand Prix is available at the bottom of this article. F1 results 2024 Mexican Grand Prix",
"dateLastCrawled": "2024-10-28T07:15:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=916492551782&mkt=en-US&setlang=en-US&w=zBsfaAPyF2tUrHFHr_vFFdUm8sng4g34",
"language": "en",
"isNavigational": false,
"noCache": false
},
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.1",
"name": "F1 Results Today: HUGE Verstappen penalties cause major title change",
"url": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max-verstappen-penalties-cause-major-title-change/",
"datePublished": "2024-10-27T00:00:00.0000000",
"datePublishedFreshnessText": "1 day ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.gpfans.com/en/f1-news/1033512/f1-results-today-mexican-grand-prix-huge-max...",
"snippet": "Elsewhere, Mercedes duo Lewis Hamilton and George Russell came home in P4 and P5 respectively. Meanwhile, the surprise package of the day were Haas, with both Kevin Magnussen and Nico Hulkenberg finishing inside the points.. READ MORE: RB star issues apology after red flag CRASH at Mexican GP Mexican Grand Prix 2024 results. 1. Carlos Sainz [Ferrari] 2. Lando Norris [McLaren] - +4.705",
"dateLastCrawled": "2024-10-28T06:06:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=2840656522642&mkt=en-US&setlang=en-US&w=-Tbkwxnq52jZCvG7l3CtgcwT1vwAjIUD",
"language": "en",
"isNavigational": false,
"noCache": false
},
{
"id": "https://api.bing.microsoft.com/api/v7/#WebPages.2",
"name": "International Power Rankings: England flying, Kangaroos cruising, Fiji rise",
"url": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying-kangaroos-cruising-fiji-rise",
"datePublished": "2024-10-28T00:00:00.0000000",
"datePublishedFreshnessText": "7 hours ago",
"isFamilyFriendly": true,
"displayUrl": "https://www.loverugbyleague.com/post/international-power-rankings-england-flying...",
"snippet": "LRL RECOMMENDS: England player ratings from first Test against Samoa as omnificent George Williams scores perfect 10. 2. Australia (Men) SAME. The Kangaroos remain 2nd in our Power Rankings after their 22-10 win against New Zealand in Christchurch on Sunday. As was the case in their win against Tonga last week, Mal Meningas side weren ...",
"dateLastCrawled": "2024-10-28T07:09:00.0000000Z",
"cachedPageUrl": "https://cc.bingj.com/cache.aspx?q=Top+10+international+results&d=1535008462672&mkt=en-US&setlang=en-US&w=82ujhH4Kp0iuhCS7wh1xLUFYUeetaVVm",
"language": "en",
"isNavigational": false,
"noCache": false
}
],
"someResultsRemoved": true
}
}

View File

@@ -1,3 +1,5 @@
# TODO: move socket to webui app
import asyncio
import socketio
import logging

View File

@@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import (
chats,
folders,
configs,
groups,
files,
functions,
memories,
@@ -34,6 +35,7 @@ from open_webui.config import (
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_API_KEY,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
@@ -50,9 +52,22 @@ from open_webui.config import (
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
ENABLE_LDAP,
LDAP_SERVER_LABEL,
LDAP_SERVER_HOST,
LDAP_SERVER_PORT,
LDAP_ATTRIBUTE_FOR_USERNAME,
LDAP_SEARCH_FILTERS,
LDAP_SEARCH_BASE,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
LDAP_USE_TLS,
LDAP_CA_CERT_FILE,
LDAP_CIPHERS,
AppConfig,
)
from open_webui.env import (
ENV,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
@@ -72,7 +87,11 @@ from open_webui.utils.payload import (
from open_webui.utils.tools import get_tools
app = FastAPI()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
log = logging.getLogger(__name__)
@@ -80,6 +99,8 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
@@ -92,6 +113,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
@@ -111,7 +134,19 @@ app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {}
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
@@ -135,13 +170,15 @@ app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(groups.router, prefix="/groups", tags=["groups"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@@ -336,7 +373,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
return params
async def generate_function_chat_completion(form_data, user):
async def generate_function_chat_completion(form_data, user, models: dict = {}):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
@@ -372,6 +409,7 @@ async def generate_function_chat_completion(form_data, user):
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
}
extra_params["__tools__"] = get_tools(
app,
@@ -379,7 +417,7 @@ async def generate_function_chat_completion(form_data, user):
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},

View File

@@ -64,6 +64,11 @@ class SigninForm(BaseModel):
password: str
class LdapForm(BaseModel):
user: str
password: str
class ProfileImageUrlForm(BaseModel):
profile_image_url: str

View File

@@ -203,15 +203,22 @@ class ChatTable:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
print("update_shared_chat_by_id")
chat = db.get(Chat, chat_id)
print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
)
return self.get_chat_by_id(chat.share_id)
if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id)
shared_chat.title = chat.title
shared_chat.chat = chat.chat
shared_chat.updated_at = int(time.time())
db.commit()
db.refresh(shared_chat)
return ChatModel.model_validate(shared_chat)
except Exception:
return None

View File

@@ -1,157 +0,0 @@
import json
import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Documents DB Schema
####################
class Document(Base):
__tablename__ = "document"
collection_name = Column(String, primary_key=True)
name = Column(String, unique=True)
title = Column(Text)
filename = Column(Text)
content = Column(Text, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str
name: str
title: str
filename: str
content: Optional[str] = None
user_id: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class DocumentResponse(BaseModel):
collection_name: str
name: str
title: str
filename: str
content: Optional[dict] = None
user_id: str
timestamp: int # timestamp in epoch
class DocumentUpdateForm(BaseModel):
name: str
title: str
class DocumentForm(DocumentUpdateForm):
collection_name: str
filename: str
content: Optional[str] = None
class DocumentsTable:
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"timestamp": int(time.time()),
}
)
try:
result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return DocumentModel.model_validate(result)
else:
return None
except Exception:
return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try:
with get_db() as db:
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except Exception:
return None
def get_docs(self) -> list[DocumentModel]:
with get_db() as db:
return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
with get_db() as db:
db.query(Document).filter_by(name=name).update(
{
"title": form_data.title,
"name": form_data.name,
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e:
log.exception(e)
return None
def update_doc_content_by_name(
self, name: str, updated: dict
) -> Optional[DocumentModel]:
try:
doc = self.get_doc_by_name(name)
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
with get_db() as db:
db.query(Document).filter_by(name=name).update(
{
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(name)
except Exception as e:
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
with get_db() as db:
db.query(Document).filter_by(name=name).delete()
db.commit()
return True
except Exception:
return False
Documents = DocumentsTable()

View File

@@ -0,0 +1,186 @@
import json
import logging
import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
####################
class Group(Base):
__tablename__ = "group"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
name = Column(Text)
description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class GroupResponse(BaseModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
name: str
description: str
class GroupUpdateForm(GroupForm):
permissions: Optional[dict] = None
user_ids: Optional[list[str]] = None
admin_ids: Optional[list[str]] = None
class GroupTable:
def insert_new_group(
self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]:
with get_db() as db:
group = GroupModel(
**{
**form_data.model_dump(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Group(**group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return GroupModel.model_validate(result)
else:
return None
except Exception:
return None
def get_groups(self) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None
except Exception:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_group_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def delete_group_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_groups(self) -> bool:
with get_db() as db:
try:
db.query(Group).delete()
db.commit()
return True
except Exception:
return False
Groups = GroupTable()

View File

@@ -8,11 +8,13 @@ from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.apps.webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -34,6 +36,23 @@ class Knowledge(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@@ -50,6 +69,8 @@ class KnowledgeModel(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -59,15 +80,15 @@ class KnowledgeModel(BaseModel):
####################
class KnowledgeResponse(BaseModel):
id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class KnowledgeUserModel(KnowledgeModel):
user: Optional[UserResponse] = None
class KnowledgeResponse(KnowledgeModel):
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeUserResponse(KnowledgeUserModel):
files: Optional[list[FileMetadataResponse | dict]] = None
@@ -75,12 +96,7 @@ class KnowledgeForm(BaseModel):
name: str
description: str
data: Optional[dict] = None
class KnowledgeUpdateForm(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
data: Optional[dict] = None
access_control: Optional[dict] = None
class KnowledgeTable:
@@ -110,14 +126,33 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_items(self) -> list[KnowledgeModel]:
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
with get_db() as db:
return [
KnowledgeModel.model_validate(knowledge)
for knowledge in db.query(Knowledge)
.order_by(Knowledge.updated_at.desc())
.all()
]
knowledge_bases = []
for knowledge in (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
):
user = Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(knowledge).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return knowledge_bases
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases()
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
@@ -128,14 +163,32 @@ class KnowledgeTable:
return None
def update_knowledge_by_id(
self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
**form_data.model_dump(),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def update_knowledge_data_by_id(
self, id: str, data: dict
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
"data": data,
"updated_at": int(time.time()),
}
)

View File

@@ -4,8 +4,19 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -67,6 +78,25 @@ class Model(Base):
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
is_active = Column(Boolean, default=True)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
@@ -80,6 +110,9 @@ class ModelModel(BaseModel):
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
is_active: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -91,12 +124,12 @@ class ModelModel(BaseModel):
####################
class ModelResponse(BaseModel):
id: str
name: str
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ModelUserResponse(ModelModel):
user: Optional[UserResponse] = None
class ModelResponse(ModelModel):
pass
class ModelForm(BaseModel):
@@ -105,6 +138,8 @@ class ModelForm(BaseModel):
name: str
meta: ModelMeta
params: ModelParams
access_control: Optional[dict] = None
is_active: bool = True
class ModelsTable:
@@ -138,6 +173,39 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]:
with get_db() as db:
models = []
for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id)
models.append(
ModelUserResponse.model_validate(
{
**ModelModel.model_validate(model).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return models
def get_base_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
@@ -146,6 +214,23 @@ class ModelsTable:
except Exception:
return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db:
try:
is_active = db.query(Model).filter_by(id=id).first().is_active
db.query(Model).filter_by(id=id).update(
{
"is_active": not is_active,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_model_by_id(id)
except Exception:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
with get_db() as db:
@@ -153,7 +238,7 @@ class ModelsTable:
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}, exclude_none=True))
.update(model.model_dump(exclude={"id"}))
)
db.commit()
@@ -175,5 +260,15 @@ class ModelsTable:
except Exception:
return False
def delete_all_models(self) -> bool:
try:
with get_db() as db:
db.query(Model).delete()
db.commit()
return True
except Exception:
return False
Models = ModelsTable()

View File

@@ -2,8 +2,12 @@ import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
####################
# Prompts DB Schema
@@ -19,6 +23,23 @@ class Prompt(Base):
content = Column(Text)
timestamp = Column(BigInteger)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
class PromptModel(BaseModel):
command: str
@@ -27,6 +48,7 @@ class PromptModel(BaseModel):
content: str
timestamp: int # timestamp in epoch
access_control: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
@@ -35,10 +57,15 @@ class PromptModel(BaseModel):
####################
class PromptUserResponse(PromptModel):
user: Optional[UserResponse] = None
class PromptForm(BaseModel):
command: str
title: str
content: str
access_control: Optional[dict] = None
class PromptsTable:
@@ -48,16 +75,14 @@ class PromptsTable:
prompt = PromptModel(
**{
"user_id": user_id,
"command": form_data.command,
"title": form_data.title,
"content": form_data.content,
**form_data.model_dump(),
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = Prompt(**prompt.dict())
result = Prompt(**prompt.model_dump())
db.add(result)
db.commit()
db.refresh(result)
@@ -76,11 +101,34 @@ class PromptsTable:
except Exception:
return None
def get_prompts(self) -> list[PromptModel]:
def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db:
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
prompts = []
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id)
prompts.append(
PromptUserResponse.model_validate(
{
**PromptModel.model_validate(prompt).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return prompts
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]:
prompts = self.get_prompts()
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
@@ -90,6 +138,7 @@ class PromptsTable:
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title
prompt.content = form_data.content
prompt.access_control = form_data.access_control
prompt.timestamp = int(time.time())
db.commit()
return PromptModel.model_validate(prompt)

View File

@@ -3,10 +3,13 @@ import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -26,6 +29,24 @@ class Tool(Base):
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
@@ -42,6 +63,8 @@ class ToolModel(BaseModel):
content: str
specs: list[dict]
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -58,15 +81,21 @@ class ToolResponse(BaseModel):
user_id: str
name: str
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
access_control: Optional[dict] = None
class ToolValves(BaseModel):
@@ -109,9 +138,32 @@ class ToolsTable:
except Exception:
return None
def get_tools(self) -> list[ToolModel]:
def get_tools(self) -> list[ToolUserResponse]:
with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = Users.get_user_by_id(tool.user_id)
tools.append(
ToolUserResponse.model_validate(
{
**ToolModel.model_validate(tool).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return tools
def get_tools_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ToolUserResponse]:
tools = self.get_tools()
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:

View File

@@ -62,6 +62,14 @@ class UserModel(BaseModel):
####################
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str
profile_image_url: str
class UserRoleUpdateForm(BaseModel):
id: str
role: str

View File

@@ -2,12 +2,14 @@ import re
import uuid
import time
import datetime
import logging
from open_webui.apps.webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
Token,
LdapForm,
SigninForm,
SigninResponse,
SignupForm,
@@ -16,13 +18,15 @@ from open_webui.apps.webui.models.auths import (
UserResponse,
)
from open_webui.apps.webui.models.users import Users
from open_webui.config import WEBUI_AUTH
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
@@ -37,10 +41,19 @@ from open_webui.utils.utils import (
get_password_hash,
)
from open_webui.utils.webhook import post_webhook
from typing import Optional
from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, ALL, Tls
from ldap3.utils.conv import escape_filter_chars
router = APIRouter()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
############################
# GetSessionUser
############################
@@ -48,6 +61,7 @@ router = APIRouter()
class SessionUserResponse(Token, UserResponse):
expires_at: Optional[int] = None
permissions: Optional[dict] = None
@router.get("/", response_model=SessionUserResponse)
@@ -80,6 +94,10 @@ async def get_session_user(
secure=WEBUI_SESSION_COOKIE_SECURE,
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@@ -89,6 +107,7 @@ async def get_session_user(
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
@@ -137,6 +156,140 @@ async def update_password(
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
############################
# LDAP Authentication
############################
@router.post("/ldap", response_model=SigninResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT
LDAP_ATTRIBUTE_FOR_USERNAME = request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME
LDAP_SEARCH_BASE = request.app.state.config.LDAP_SEARCH_BASE
LDAP_SEARCH_FILTERS = request.app.state.config.LDAP_SEARCH_FILTERS
LDAP_APP_DN = request.app.state.config.LDAP_APP_DN
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS
else "ALL"
)
if not ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled")
try:
tls = Tls(
validate=CERT_REQUIRED,
version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS,
)
except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}")
raise HTTPException(400, detail=str(e))
try:
server = Server(
host=LDAP_SERVER_HOST,
port=LDAP_SERVER_PORT,
get_info=ALL,
use_ssl=LDAP_USE_TLS,
tls=tls,
)
connection_app = Connection(
server,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed")
search_success = connection_app.search(
search_base=LDAP_SEARCH_BASE,
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
)
if not search_success:
raise HTTPException(400, detail="User not found in the LDAP server")
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry["mail"])
cn = str(entry["cn"])
user_dn = entry.entry_dn
if username == form_data.user.lower():
connection_user = Connection(
server,
user_dn,
form_data.password,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}")
user = Users.get_user_by_email(mail)
if not user:
try:
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(mail, hashed, cn)
if not user:
raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
)
except HTTPException:
raise
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
user = Auths.authenticate_user(mail, password=str(form_data.password))
if user:
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(
request.app.state.config.JWT_EXPIRES_IN
),
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
raise HTTPException(
400,
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
except Exception as e:
raise HTTPException(400, detail=str(e))
############################
# SignIn
############################
@@ -211,6 +364,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
secure=WEBUI_SESSION_COOKIE_SECURE,
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@@ -220,6 +377,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@@ -260,6 +418,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
if Users.get_num_users() == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
@@ -307,6 +470,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
},
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@@ -316,6 +483,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@@ -413,6 +581,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -423,6 +592,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
ENABLE_API_KEY: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@@ -435,6 +605,7 @@ async def update_admin_config(
):
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -453,6 +624,7 @@ async def update_admin_config(
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -460,6 +632,105 @@ async def update_admin_config(
}
class LdapServerConfig(BaseModel):
label: str
host: str
port: Optional[int] = None
attribute_for_username: str = "uid"
app_dn: str
app_dn_password: str
search_base: str
search_filters: str = ""
use_tls: bool = True
certificate_path: Optional[str] = None
ciphers: Optional[str] = "ALL"
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
return {
"label": request.app.state.config.LDAP_SERVER_LABEL,
"host": request.app.state.config.LDAP_SERVER_HOST,
"port": request.app.state.config.LDAP_SERVER_PORT,
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
"app_dn": request.app.state.config.LDAP_APP_DN,
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@router.post("/admin/config/ldap/server")
async def update_ldap_server(
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
):
required_fields = [
"label",
"host",
"attribute_for_username",
"app_dn",
"app_dn_password",
"search_base",
]
for key in required_fields:
value = getattr(form_data, key)
if not value:
raise HTTPException(400, detail=f"Required field {key} is empty")
if form_data.use_tls and not form_data.certificate_path:
raise HTTPException(
400, detail="TLS is enabled but certificate file path is missing"
)
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
request.app.state.config.LDAP_SERVER_HOST = form_data.host
request.app.state.config.LDAP_SERVER_PORT = form_data.port
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
form_data.attribute_for_username
)
request.app.state.config.LDAP_APP_DN = form_data.app_dn
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
return {
"label": request.app.state.config.LDAP_SERVER_LABEL,
"host": request.app.state.config.LDAP_SERVER_HOST,
"port": request.app.state.config.LDAP_SERVER_PORT,
"attribute_for_username": request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME,
"app_dn": request.app.state.config.LDAP_APP_DN,
"app_dn_password": request.app.state.config.LDAP_APP_PASSWORD,
"search_base": request.app.state.config.LDAP_SEARCH_BASE,
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@router.get("/admin/config/ldap")
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
class LdapConfigForm(BaseModel):
enable_ldap: Optional[bool] = None
@router.post("/admin/config/ldap")
async def update_ldap_config(
request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
############################
# API Key
############################
@@ -467,9 +738,16 @@ async def update_admin_config(
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
async def generate_api_key(request: Request, user=Depends(get_current_user)):
if not request.app.state.config.ENABLE_API_KEY:
raise HTTPException(
status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
)
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,

View File

@@ -17,7 +17,10 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -50,9 +53,10 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get(
"chat", {}
).get("deletion", {}):
if user.role == "user" and not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@@ -385,8 +389,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result
else:
if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
"deletion", {}
if not has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View File

@@ -0,0 +1,120 @@
import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.groups import (
Groups,
GroupForm,
GroupUpdateForm,
GroupResponse,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetFunctions
############################
@router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)):
if user.role == "admin":
return Groups.get_groups()
else:
return Groups.get_groups_by_member_id(user.id)
############################
# CreateNewGroup
############################
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
return group
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# GetGroupById
############################
@router.get("/id/{id}", response_model=Optional[GroupResponse])
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id)
if group:
return group
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateGroupById
############################
@router.post("/id/{id}/update", response_model=Optional[GroupResponse])
async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
):
try:
group = Groups.update_group_by_id(id, form_data)
if group:
return group
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteGroupById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
try:
result = Groups.delete_group_by_id(id)
if result:
return result
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)

View File

@@ -1,14 +1,14 @@
import json
from typing import Optional, Union
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
from open_webui.apps.webui.models.knowledge import (
Knowledges,
KnowledgeUpdateForm,
KnowledgeForm,
KnowledgeResponse,
KnowledgeUserResponse,
)
from open_webui.apps.webui.models.files import Files, FileModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
@@ -17,6 +17,9 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
@@ -26,64 +29,103 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetKnowledgeItems
# getKnowledgeBases
############################
@router.get(
"/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
)
async def get_knowledge_items(
id: Optional[str] = None, user=Depends(get_verified_user)
):
if id:
knowledge = Knowledges.get_knowledge_by_id(id=id)
@router.get("/", response_model=list[KnowledgeUserResponse])
async def get_knowledge(user=Depends(get_verified_user)):
knowledge_bases = []
if knowledge:
return knowledge
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if user.role == "admin":
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = []
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
for knowledge in Knowledges.get_knowledge_items():
files = []
if knowledge.data:
files = Files.get_file_metadatas_by_ids(
knowledge.data.get("file_ids", [])
)
# Check if all files exist
if len(files) != len(knowledge.data.get("file_ids", [])):
missing_files = list(
set(knowledge.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_by_id(
id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_bases.append(
KnowledgeResponse(
**knowledge.model_dump(),
files=files,
)
# Get files for each knowledge base
knowledge_with_files = []
for knowledge_base in knowledge_bases:
files = []
if knowledge_base.data:
files = Files.get_file_metadatas_by_ids(
knowledge_base.data.get("file_ids", [])
)
return knowledge_bases
# Check if all files exist
if len(files) != len(knowledge_base.data.get("file_ids", [])):
missing_files = list(
set(knowledge_base.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge_base.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data=data
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_with_files.append(
KnowledgeUserResponse(
**knowledge_base.model_dump(),
files=files,
)
)
return knowledge_with_files
@router.get("/list", response_model=list[KnowledgeUserResponse])
async def get_knowledge_list(user=Depends(get_verified_user)):
knowledge_bases = []
if user.role == "admin":
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
# Get files for each knowledge base
knowledge_with_files = []
for knowledge_base in knowledge_bases:
files = []
if knowledge_base.data:
files = Files.get_file_metadatas_by_ids(
knowledge_base.data.get("file_ids", [])
)
# Check if all files exist
if len(files) != len(knowledge_base.data.get("file_ids", [])):
missing_files = list(
set(knowledge_base.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge_base.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_data_by_id(
id=knowledge_base.id, data=data
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_with_files.append(
KnowledgeUserResponse(
**knowledge_base.model_dump(),
files=files,
)
)
return knowledge_with_files
############################
@@ -92,7 +134,17 @@ async def get_knowledge_items(
@router.post("/create", response_model=Optional[KnowledgeResponse])
async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
async def create_new_knowledge(
request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
if knowledge:
@@ -118,13 +170,20 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if knowledge:
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=files,
)
if (
user.role == "admin"
or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control)
):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=files,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -140,11 +199,23 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
async def update_knowledge_by_id(
id: str,
form_data: KnowledgeUpdateForm,
user=Depends(get_admin_user),
form_data: KnowledgeForm,
user=Depends(get_verified_user),
):
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
if knowledge:
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
@@ -173,9 +244,22 @@ class KnowledgeFileIdForm(BaseModel):
def add_file_to_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
if not file:
raise HTTPException(
@@ -206,9 +290,7 @@ def add_file_to_knowledge_by_id(
file_ids.append(form_data.file_id)
data["file_ids"] = file_ids
knowledge = Knowledges.update_knowledge_by_id(
id=id, form_data=KnowledgeUpdateForm(data=data)
)
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
@@ -238,9 +320,21 @@ def add_file_to_knowledge_by_id(
def update_file_from_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
if not file:
raise HTTPException(
@@ -288,9 +382,21 @@ def update_file_from_knowledge_by_id(
def remove_file_from_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
file = Files.get_file_by_id(form_data.file_id)
if not file:
raise HTTPException(
@@ -318,9 +424,7 @@ def remove_file_from_knowledge_by_id(
file_ids.remove(form_data.file_id)
data["file_ids"] = file_ids
knowledge = Knowledges.update_knowledge_by_id(
id=id, form_data=KnowledgeUpdateForm(data=data)
)
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
@@ -346,32 +450,26 @@ def remove_file_from_knowledge_by_id(
)
############################
# ResetKnowledgeById
############################
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
try:
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
except Exception as e:
log.debug(e)
pass
knowledge = Knowledges.update_knowledge_by_id(
id=id, form_data=KnowledgeUpdateForm(data={"file_ids": []})
)
return knowledge
############################
# DeleteKnowledgeById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
try:
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
except Exception as e:
@@ -379,3 +477,34 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
pass
result = Knowledges.delete_knowledge_by_id(id=id)
return result
############################
# ResetKnowledgeById
############################
@router.post("/{id}/reset", response_model=Optional[KnowledgeResponse])
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
try:
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
except Exception as e:
log.debug(e)
pass
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
return knowledge

View File

@@ -4,53 +4,71 @@ from open_webui.apps.webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,
ModelUserResponse,
Models,
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
router = APIRouter()
###########################
# getModels
# GetModels
###########################
@router.get("/", response_model=list[ModelResponse])
@router.get("/", response_model=list[ModelUserResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if id:
model = Models.get_model_by_id(id)
if model:
return [model]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if user.role == "admin":
return Models.get_models()
else:
return Models.get_all_models()
return Models.get_models_by_user_id(user.id)
###########################
# GetBaseModels
###########################
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models()
############################
# AddNewModel
# CreateNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
if form_data.id in request.app.state.MODELS:
if user.role != "admin" and not has_permission(
user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
model = Models.get_model_by_id(form_data.id)
if model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
@@ -60,37 +78,85 @@ async def add_new_model(
)
###########################
# GetModelById
###########################
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "read", model.access_control)
):
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# ToggelModelById
############################
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "write", model.access_control)
):
model = Models.toggle_model_by_id(id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelById
############################
@router.post("/update", response_model=Optional[ModelModel])
@router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
model = Models.update_model_by_id(id, form_data)
return model
############################
@@ -98,7 +164,26 @@ async def update_model_by_id(
############################
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
@router.delete("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if model.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Models.delete_model_by_id(id)
return result
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
result = Models.delete_all_models()
return result

View File

@@ -1,9 +1,15 @@
from typing import Optional
from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompts
from open_webui.apps.webui.models.prompts import (
PromptForm,
PromptUserResponse,
PromptModel,
Prompts,
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
router = APIRouter()
@@ -14,7 +20,22 @@ router = APIRouter()
@router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)):
return Prompts.get_prompts()
if user.role == "admin":
prompts = Prompts.get_prompts()
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
return prompts
@router.get("/list", response_model=list[PromptUserResponse])
async def get_prompt_list(user=Depends(get_verified_user)):
if user.role == "admin":
prompts = Prompts.get_prompts()
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
return prompts
############################
@@ -23,7 +44,17 @@ async def get_prompts(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
async def create_new_prompt(
request: Request, form_data: PromptForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
@@ -50,7 +81,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
return prompt
if (
user.role == "admin"
or prompt.user_id == user.id
or has_access(user.id, "read", prompt.access_control)
):
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -67,8 +103,21 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
async def update_prompt_by_command(
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if not prompt:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if prompt.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
@@ -85,6 +134,19 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if not prompt:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if prompt.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View File

@@ -2,50 +2,82 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports
from open_webui.apps.webui.models.tools import (
ToolForm,
ToolModel,
ToolResponse,
ToolUserResponse,
Tools,
)
from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, DATA_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
router = APIRouter()
############################
# GetToolkits
# GetTools
############################
@router.get("/", response_model=list[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "read")
return tools
############################
# ExportToolKits
# GetToolList
############################
@router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "write")
return tools
############################
# ExportTools
############################
@router.get("/export", response_model=list[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
async def export_tools(user=Depends(get_admin_user)):
tools = Tools.get_tools()
return tools
############################
# CreateNewToolKit
# CreateNewTools
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
async def create_new_tools(
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
if user.role != "admin" and not has_permission(
user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -54,30 +86,30 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit is None:
tools = Tools.get_tool_by_id(form_data.id)
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
toolkit_module, frontmatter = load_toolkit_module_by_id(
tools_module, frontmatter = load_tools_module_by_id(
form_data.id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
TOOLS[form_data.id] = tools_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
if toolkit:
return toolkit
if tools:
return tools
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
)
except Exception as e:
print(e)
@@ -93,16 +125,21 @@ async def create_new_toolkit(
############################
# GetToolkitById
# GetToolsById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
if tools:
if (
user.role == "admin"
or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control)
):
return tools
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -111,26 +148,39 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
############################
# UpdateToolkitById
# UpdateToolsById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
async def update_tools_by_id(
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
tools = Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if tools.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try:
form_data.content = replace_imports(form_data.content)
toolkit_module, frontmatter = load_toolkit_module_by_id(
tools_module, frontmatter = load_tools_module_by_id(
id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
TOOLS[id] = tools_module
specs = get_tools_specs(TOOLS[id])
@@ -140,14 +190,14 @@ async def update_toolkit_by_id(
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
tools = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
if tools:
return tools
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
)
except Exception as e:
@@ -158,14 +208,28 @@ async def update_toolkit_by_id(
############################
# DeleteToolkitById
# DeleteToolsById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if tools.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
@@ -180,9 +244,9 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if tools:
try:
valves = Tools.get_tool_valves_by_id(id)
return valves
@@ -204,19 +268,19 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_toolkit_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
tools_module = request.app.state.TOOLS[id]
else:
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
if hasattr(tools_module, "Valves"):
Valves = tools_module.Valves
return Valves.schema()
return None
else:
@@ -232,19 +296,19 @@ async def get_toolkit_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_toolkit_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
tools_module = request.app.state.TOOLS[id]
else:
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
if hasattr(tools_module, "Valves"):
Valves = tools_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
@@ -276,9 +340,9 @@ async def update_toolkit_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
if tools:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
@@ -295,19 +359,19 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_toolkit_user_valves_spec_by_id(
async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
tools = Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
tools_module = request.app.state.TOOLS[id]
else:
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
if hasattr(tools_module, "UserValves"):
UserValves = tools_module.UserValves
return UserValves.schema()
return None
else:
@@ -318,20 +382,20 @@ async def get_toolkit_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_toolkit_user_valves_by_id(
async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
tools = Tools.get_tool_by_id(id)
if toolkit:
if tools:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
tools_module = request.app.state.TOOLS[id]
else:
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
if hasattr(tools_module, "UserValves"):
UserValves = tools_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}

View File

@@ -31,21 +31,58 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
return Users.get_users(skip, limit)
############################
# User Groups
############################
@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
############################
# User Permissions
############################
@router.get("/permissions/user")
@router.get("/permissions")
async def get_user_permissisions(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
############################
# User Default Permissions
############################
class WorkspacePermissions(BaseModel):
models: bool
knowledge: bool
prompts: bool
tools: bool
class ChatPermissions(BaseModel):
file_upload: bool
delete: bool
edit: bool
temporary: bool
class UserPermissions(BaseModel):
workspace: WorkspacePermissions
chat: ChatPermissions
@router.get("/default/permissions")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user")
@router.post("/default/permissions")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
return request.app.state.config.USER_PERMISSIONS

View File

@@ -63,7 +63,7 @@ def replace_imports(content):
return content
def load_toolkit_module_by_id(toolkit_id, content=None):
def load_tools_module_by_id(toolkit_id, content=None):
if content is None:
tool = Tools.get_tool_by_id(toolkit_id)