This commit is contained in:
Timothy Jaeryang Baek 2024-12-10 00:54:13 -08:00
parent f6bec8d9f3
commit d3d161f723
112 changed files with 1217 additions and 1165 deletions

View File

@ -10,7 +10,7 @@ from urllib.parse import urlparse
import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import (
OPEN_WEBUI_DIR,
DATA_DIR,
@ -432,7 +432,10 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
OAUTH_ALLOWED_DOMAINS = PersistentConfig(
"OAUTH_ALLOWED_DOMAINS",
"oauth.allowed_domains",
[domain.strip() for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")],
[
domain.strip()
for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
],
)

View File

@ -3,7 +3,7 @@ import logging
from contextlib import contextmanager
from typing import Any, Optional
from open_webui.apps.webui.internal.wrappers import register_connection
from open_webui.internal.wrappers import register_connection
from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
from logging.config import fileConfig
from alembic import context
from open_webui.apps.webui.models.auths import Auth
from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool

View File

@ -9,7 +9,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.apps.webui.internal.db
import open_webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.

View File

@ -11,8 +11,8 @@ from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
import open_webui.apps.webui.internal.db
from open_webui.apps.webui.internal.db import JSONField
import open_webui.internal.db
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic.

View File

@ -2,8 +2,8 @@ import logging
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,8 +2,8 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text

View File

@ -4,10 +4,10 @@ import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict

View File

@ -4,11 +4,11 @@ import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

View File

@ -2,10 +2,10 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict

View File

@ -1,8 +1,8 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.internal.db import Base, get_db
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -3,7 +3,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,8 +2,8 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -1,8 +1,8 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

View File

@ -40,7 +40,7 @@ class PgvectorClient:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.apps.webui.internal.db import Session
from open_webui.internal.db import Session
self.session = Session
else:

View File

@ -25,11 +25,10 @@ from open_webui.config import (
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
CACHE_DIR,
AppConfig,
)
@ -55,44 +54,6 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"

View File

@ -5,7 +5,7 @@ import datetime
import logging
from aiohttp import ClientSession
from open_webui.apps.webui.models.auths import (
from open_webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
@ -18,7 +18,7 @@ from open_webui.apps.webui.models.auths import (
UpdateProfileForm,
UserResponse,
)
from open_webui.apps.webui.models.users import Users
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (

View File

@ -0,0 +1,411 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel
router = APIRouter()
@app.post("/api/chat/completions")
async def generate_chat_completions(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: bool = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
model_list = request.state.models
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if model.get("arena"):
if not has_access(
user.id,
type="read",
access_control=model.get("info", {})
.get("meta", {})
.get("access_control", {}),
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
else:
model_info = Models.get_model_by_id(model_id)
if not model_info:
raise HTTPException(
status_code=404,
detail="Model not found",
)
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completions(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(
await generate_chat_completions(form_data, user, bypass_filter=True)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
form_data = GenerateChatCompletionForm(**form_data)
response = await generate_ollama_chat_completion(
form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.stream:
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
form_data, user=user, bypass_filter=bypass_filter
)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
data = form_data
model_id = data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
sorted_filters = get_sorted_filters(model_id, models)
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except Exception:
pass
else:
pass
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
@app.post("/api/chat/actions/{action_id}")
async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
if "." in action_id:
action_id, sub_action_id = action_id.split(".")
else:
sub_action_id = None
action = Functions.get_function_by_id(action_id)
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Action not found",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
data = form_data
model_id = data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):
try:
action = function_module.action
# Get the signature of the function
sig = inspect.signature(action)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": sub_action_id if sub_action_id is not None else action_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(action):
data = await action(**params)
else:
data = action(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data

View File

@ -2,15 +2,15 @@ import json
import logging
from typing import Optional
from open_webui.apps.webui.models.chats import (
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
ChatResponse,
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.apps.webui.models.folders import Folders
from open_webui.models.tags import TagModel, Tags
from open_webui.models.folders import Folders
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES

View File

@ -2,8 +2,8 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.feedbacks import (
from open_webui.models.users import Users, UserModel
from open_webui.models.feedbacks import (
FeedbackModel,
FeedbackResponse,
FeedbackForm,

View File

@ -8,13 +8,13 @@ import mimetypes
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import (
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from backend.open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS

View File

@ -8,12 +8,12 @@ from pydantic import BaseModel
import mimetypes
from open_webui.apps.webui.models.folders import (
from open_webui.models.folders import (
FolderForm,
FolderModel,
Folders,
)
from open_webui.apps.webui.models.chats import Chats
from open_webui.models.chats import Chats
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,13 +2,13 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.functions import (
from open_webui.models.functions import (
FunctionForm,
FunctionModel,
FunctionResponse,
Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
from backend.open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status

View File

@ -2,7 +2,7 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.groups import (
from open_webui.models.groups import (
Groups,
GroupForm,
GroupUpdateForm,

View File

@ -9,31 +9,14 @@ from pathlib import Path
from typing import Optional
import requests
from open_webui.apps.images.utils.comfyui import (
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
from open_webui.config import (
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
AUTOMATIC1111_CFG_SCALE,
AUTOMATIC1111_SAMPLER,
AUTOMATIC1111_SCHEDULER,
CACHE_DIR,
COMFYUI_BASE_URL,
COMFYUI_WORKFLOW,
COMFYUI_WORKFLOW_NODES,
CORS_ALLOW_ORIGIN,
ENABLE_IMAGE_GENERATION,
IMAGE_GENERATION_ENGINE,
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
AppConfig,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
@ -54,36 +37,6 @@ app = FastAPI(
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):

View File

@ -4,15 +4,15 @@ from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
from open_webui.apps.webui.models.knowledge import (
from open_webui.models.knowledge import (
Knowledges,
KnowledgeForm,
KnowledgeResponse,
KnowledgeUserResponse,
)
from open_webui.apps.webui.models.files import Files, FileModel
from open_webui.models.files import Files, FileModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from backend.open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.constants import ERROR_MESSAGES

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel
import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
from open_webui.models.memories import Memories, MemoryModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS

View File

@ -1,6 +1,6 @@
from typing import Optional
from open_webui.apps.webui.models.models import (
from open_webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,

View File

@ -12,15 +12,22 @@ import aiohttp
from aiocache import cached
import requests
from open_webui.apps.webui.models.models import Models
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask
from open_webui.models.models import Models
from open_webui.config import (
CORS_ALLOW_ORIGIN,
ENABLE_OLLAMA_API,
OLLAMA_BASE_URLS,
OLLAMA_API_CONFIGS,
UPLOAD_DIR,
AppConfig,
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
@ -30,11 +37,6 @@ from open_webui.env import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask
from open_webui.utils.misc import (
@ -52,27 +54,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.

View File

@ -10,7 +10,7 @@ from aiocache import cached
import requests
from open_webui.apps.webui.models.models import Models
from open_webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
@ -48,29 +48,6 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {
@ -91,7 +68,6 @@ class OpenAIConfigForm(BaseModel):
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
@ -14,86 +14,328 @@ from open_webui.utils.auth import get_admin_user
router = APIRouter()
@router.get("/gravatar")
async def get_gravatar(
email: str,
##################################
#
# Pipelines Endpoints
#
##################################
# TODO: Refactor pipelines API endpoints below into a separate file
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models_responses()
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
urlIdxs = [
idx
for idx, response in enumerate(responses)
if response is not None and "pipelines" in response
]
return {
"data": [
{
"url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
"idx": urlIdx,
}
for urlIdx in urlIdxs
]
}
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
):
return {"html": markdown.markdown(form_data.md)}
class ChatForm(BaseModel):
title: str
messages: list[dict]
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatTitleMessagesForm,
):
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
return Response(
content=pdf_bytes,
media_type="application/pdf",
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail=str(e))
@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from open_webui.apps.webui.internal.db import engine
if engine.name != "sqlite":
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
detail="Only Python (.py) files are allowed.",
)
return FileResponse(
engine.url.database,
media_type="application/octet-stream",
filename="webui.db",
)
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
r = None
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
with open(file_path, "rb") as f:
files = {"file": f}
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None:
status_code = r.status_code
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=status_code,
detail=detail,
)
finally:
# Ensure the file is deleted after the upload is completed or on failure
if os.path.exists(file_path):
os.remove(file_path)
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)
class AddPipelineForm(BaseModel):
url: str
urlIdx: int
@app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
r = None
try:
urlIdx = form_data.urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
class DeletePipelineForm(BaseModel):
id: str
urlIdx: int
@app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
r = None
try:
urlIdx = form_data.urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.delete(
f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
@app.get("/api/pipelines")
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.get(f"{url}/pipelines", headers=headers)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
):
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{pipeline_id}/valves/update",
headers=headers,
json={**form_data},
)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except Exception:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)

View File

@ -1,6 +1,6 @@
from typing import Optional
from open_webui.apps.webui.models.prompts import (
from open_webui.models.prompts import (
PromptForm,
PromptUserResponse,
PromptModel,

View File

@ -18,7 +18,7 @@ import tiktoken
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.knowledge import Knowledges
from open_webui.models.knowledge import Knowledges
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
# Document loaders
@ -43,7 +43,7 @@ from open_webui.apps.retrieval.web.tavily import search_tavily
from open_webui.apps.retrieval.web.bing import search_bing
from open_webui.apps.retrieval.utils import (
from backend.open_webui.retrieval.utils import (
get_embedding_function,
get_model_path,
query_collection,
@ -52,7 +52,7 @@ from open_webui.apps.retrieval.utils import (
query_doc_with_hybrid_search,
)
from open_webui.apps.webui.models.files import Files
from open_webui.models.files import Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
KAGI_SEARCH_API_KEY,

View File

@ -2,6 +2,8 @@ from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from pydantic import BaseModel
from starlette.responses import FileResponse
from typing import Optional
import logging
from open_webui.utils.task import (
title_generation_template,
@ -12,6 +14,17 @@ from open_webui.utils.task import (
moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.config import (
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
@ -197,7 +210,9 @@ Artificial Intelligence in Healthcare
@router.post("/tags/completions")
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_TAGS_GENERATION:
return JSONResponse(

View File

@ -1,14 +1,14 @@
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.tools import (
from open_webui.models.tools import (
ToolForm,
ToolModel,
ToolResponse,
ToolUserResponse,
Tools,
)
from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
from backend.open_webui.utils.plugin import load_tools_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status

View File

@ -1,9 +1,9 @@
import logging
from typing import Optional
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.chats import Chats
from open_webui.apps.webui.models.users import (
from open_webui.models.auths import Auths
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
UserRoleUpdateForm,
Users,

View File

@ -1,7 +1,7 @@
import black
import markdown
from open_webui.apps.webui.models.chats import ChatTitleMessagesForm
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
@ -76,7 +76,7 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from open_webui.apps.webui.internal.db import engine
from open_webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException(

View File

@ -5,9 +5,9 @@ import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.routers import (
auths,
chats,
folders,
@ -24,7 +24,7 @@ from open_webui.apps.webui.routers import (
users,
utils,
)
from open_webui.apps.webui.utils import load_function_module_by_id
from backend.open_webui.utils.plugin import load_function_module_by_id
from open_webui.config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,

View File

@ -6,7 +6,7 @@ import logging
import sys
import time
from open_webui.apps.webui.models.users import Users
from open_webui.models.users import Users
from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,

View File

@ -7,8 +7,8 @@ class TestAuths(AbstractPostgresTest):
def setup_class(cls):
super().setup_class()
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users
from open_webui.models.auths import Auths
from open_webui.models.users import Users
cls.users = Users
cls.auths = Auths

Some files were not shown because too many files have changed in this diff Show More