mirror of https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into feat/model-config
@@ -43,6 +43,7 @@ from utils.utils import (
 from config import (
     SRC_LOG_LEVELS,
     OLLAMA_BASE_URLS,
+    ENABLE_OLLAMA_API,
     ENABLE_MODEL_FILTER,
     MODEL_FILTER_LIST,
     UPLOAD_DIR,
@@ -68,6 +69,8 @@ app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 app.state.MODEL_CONFIG = Models.get_all_models()
 
+
+app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 
@@ -97,6 +100,21 @@ async def get_status():
     return {"status": True}
 
 
+@app.get("/config")
+async def get_config(user=Depends(get_admin_user)):
+    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
+
+
+class OllamaConfigForm(BaseModel):
+    enable_ollama_api: Optional[bool] = None
+
+
+@app.post("/config/update")
+async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
+    app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
+    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
+
+
 @app.get("/urls")
 async def get_ollama_api_urls(user=Depends(get_admin_user)):
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
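
Note: a minimal sketch of exercising the new config endpoints, assuming the Ollama router is mounted at /ollama on localhost:8080 and an admin bearer token is at hand (both are assumptions, not part of this diff):

    import requests

    BASE = "http://localhost:8080/ollama"  # assumed mount point of this router
    HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder admin JWT

    # Read the current flag
    print(requests.get(f"{BASE}/config", headers=HEADERS).json())

    # Disable proxying to Ollama; model listings then come back empty
    resp = requests.post(
        f"{BASE}/config/update",
        headers=HEADERS,
        json={"enable_ollama_api": False},
    )
    print(resp.json())  # expected: {'ENABLE_OLLAMA_API': False}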
@@ -157,17 +175,24 @@ def merge_models_lists(model_lists):
 
 async def get_all_models():
     log.info("get_all_models()")
-    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
-    responses = await asyncio.gather(*tasks)
-
-    models = {
-        "models": merge_models_lists(
-            map(
-                lambda response: (response["models"] if response else None),
-                responses,
-            )
-        )
-    }
+    if app.state.config.ENABLE_OLLAMA_API:
+        tasks = [
+            fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
+        ]
+        responses = await asyncio.gather(*tasks)
+
+        models = {
+            "models": merge_models_lists(
+                map(
+                    lambda response: response["models"] if response else None, responses
+                )
+            )
+        }
+
+    else:
+        models = {"models": []}
 
     for model in models["models"]:
         add_custom_info_to_model(model)
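
Note: the lambda's `response["models"] if response else None` guard only makes sense if fetch_url swallows connection errors and returns None for unreachable hosts. A sketch of that assumed shape (the real helper is defined elsewhere in this file):

    import aiohttp

    async def fetch_url(url):
        # Return parsed JSON, or None when an Ollama host is unreachable,
        # so merge_models_lists can skip dead instances instead of raising.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    return await response.json()
        except Exception:
            return None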
@@ -316,6 +316,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     idx = 0
+    pipeline = False
 
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
@@ -324,7 +325,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         body = body.decode("utf-8")
         body = json.loads(body)
 
-        idx = app.state.MODELS[body.get("model")]["urlIdx"]
+        model = app.state.MODELS[body.get("model")]
+
+        idx = model["urlIdx"]
+
+        if "pipeline" in model:
+            pipeline = model.get("pipeline")
+
+        if pipeline:
+            body["user"] = {"name": user.name, "id": user.id}
 
         # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
         # This is a workaround until OpenAI fixes the issue with this model
@@ -3,7 +3,7 @@ import json
 from peewee import *
 from peewee_migrate import Router
 from playhouse.db_url import connect
-from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
+from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 import os
 import logging
 
@@ -30,6 +30,8 @@ else:
 
 DB = connect(DATABASE_URL)
 log.info(f"Connected to a {DB.__class__.__name__} database.")
-router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
+router = Router(
+    DB, migrate_dir=BACKEND_DIR / "apps" / "web" / "internal" / "migrations", logger=log
+)
 router.run()
 DB.connect(reuse_if_open=True)
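
Note: anchoring migrate_dir to BACKEND_DIR makes the migration lookup independent of the process working directory, which matters once the backend ships as an installed package. A quick illustration:

    from pathlib import Path

    BACKEND_DIR = Path(__file__).parent  # as defined in config.py below
    # Resolves to the same directory no matter where the server was launched from:
    print(BACKEND_DIR / "apps" / "web" / "internal" / "migrations")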
@@ -1,6 +1,8 @@
 import os
 import sys
 import logging
+import importlib.metadata
+import pkgutil
 import chromadb
 from chromadb import Settings
 from base64 import b64encode
@@ -22,10 +24,13 @@ from constants import ERROR_MESSAGES
 # Load .env file
 ####################################
 
+BACKEND_DIR = Path(__file__).parent  # the path containing this file
+BASE_DIR = BACKEND_DIR.parent  # the path containing backend/
+
 try:
     from dotenv import load_dotenv, find_dotenv
 
-    load_dotenv(find_dotenv("../.env"))
+    load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
 except ImportError:
     print("dotenv not installed, skipping...")
 
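
Note: find_dotenv searches for the named file rather than requiring an exact location; passing the absolute str(BASE_DIR / ".env") instead of the CWD-relative "../.env" pins the lookup to the repository root. A sketch:

    from pathlib import Path
    from dotenv import find_dotenv, load_dotenv

    BASE_DIR = Path(__file__).parent.parent
    load_dotenv(find_dotenv(str(BASE_DIR / ".env")))  # no dependency on os.getcwd()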
@@ -87,10 +92,12 @@ WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 ENV = os.environ.get("ENV", "dev")
 
 try:
-    with open(f"../package.json", "r") as f:
-        PACKAGE_DATA = json.load(f)
+    PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
 except:
-    PACKAGE_DATA = {"version": "0.0.0"}
+    try:
+        PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
+    except importlib.metadata.PackageNotFoundError:
+        PACKAGE_DATA = {"version": "0.0.0"}
 
 VERSION = PACKAGE_DATA["version"]
 
@@ -115,10 +122,10 @@ def parse_section(section):
 
 
 try:
-    with open("../CHANGELOG.md", "r") as file:
-        changelog_content = file.read()
+    changelog_content = (BASE_DIR / "CHANGELOG.md").read_text()
 except:
-    changelog_content = ""
+    changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
+
 
 # Convert markdown content to HTML
 html_content = markdown.markdown(changelog_content)
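
Note: pkgutil.get_data reads a resource bundled with an installed package (even from a zip archive) and returns bytes, or None if the resource is missing, hence the `or b""` guard before decoding. Minimal sketch:

    import pkgutil

    data = pkgutil.get_data("open_webui", "CHANGELOG.md")  # bytes or None
    changelog = (data or b"").decode()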
@@ -164,12 +171,11 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
 # DATA/FRONTEND BUILD DIR
 ####################################
 
-DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
-FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
+DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
+FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
 
 try:
-    with open(f"{DATA_DIR}/config.json", "r") as f:
-        CONFIG_DATA = json.load(f)
+    CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
 except:
     CONFIG_DATA = {}
 
@@ -279,11 +285,11 @@ JWT_EXPIRES_IN = PersistentConfig(
 # Static DIR
 ####################################
 
-STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
+STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
 
-frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png"
-if os.path.exists(frontend_favicon):
-    shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png")
+frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
+if frontend_favicon.exists():
+    shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
 else:
     logging.warning(f"Frontend favicon not found at {frontend_favicon}")
 
@@ -378,6 +384,13 @@ if not os.path.exists(LITELLM_CONFIG_PATH):
 # OLLAMA_BASE_URL
 ####################################
 
+
+ENABLE_OLLAMA_API = PersistentConfig(
+    "ENABLE_OLLAMA_API",
+    "ollama.enable",
+    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
+)
+
 OLLAMA_API_BASE_URL = os.environ.get(
     "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
 )
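
Note: with the default "True" and the .lower() comparison, the API stays enabled unless the variable is explicitly set to something other than a case-variant of "true". The same parsing pattern as a reusable helper, for illustration only:

    import os

    def env_flag(name: str, default: str = "True") -> bool:
        # "True", "true", "TRUE" -> True; any other value -> False
        return os.environ.get(name, default).lower() == "true"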
@@ -8,6 +8,7 @@ import sys
 import logging
 import aiohttp
 import requests
+import mimetypes
 
 from fastapi import FastAPI, Request, Depends, status
 from fastapi.staticfiles import StaticFiles
@@ -437,6 +438,7 @@ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 
 if os.path.exists(FRONTEND_BUILD_DIR):
+    mimetypes.add_type("text/javascript", ".js")
    app.mount(
         "/",
         SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
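
Note: registering .js explicitly guards against hosts (notably some Windows setups) whose MIME registry maps .js to text/plain, which would make browsers reject the SPA's ES modules:

    import mimetypes

    mimetypes.add_type("text/javascript", ".js")
    print(mimetypes.guess_type("app.js"))  # ('text/javascript', None)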

backend/open_webui/__init__.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import base64
+import os
+import random
+from pathlib import Path
+
+import typer
+import uvicorn
+
+app = typer.Typer()
+
+KEY_FILE = Path.cwd() / ".webui_secret_key"
+if (frontend_build_dir := Path(__file__).parent / "frontend").exists():
+    os.environ["FRONTEND_BUILD_DIR"] = str(frontend_build_dir)
+
+
+@app.command()
+def serve(
+    host: str = "0.0.0.0",
+    port: int = 8080,
+):
+    if os.getenv("WEBUI_SECRET_KEY") is None:
+        typer.echo(
+            "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
+        )
+        if not KEY_FILE.exists():
+            typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
+            KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
+        typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
+        os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
+
+    if os.getenv("USE_CUDA_DOCKER", "false") == "true":
+        typer.echo(
+            "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
+        )
+        LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
+        os.environ["LD_LIBRARY_PATH"] = ":".join(
+            LD_LIBRARY_PATH
+            + [
+                "/usr/local/lib/python3.11/site-packages/torch/lib",
+                "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
+            ]
+        )
+    import main  # we need to set environment variables before importing main
+
+    uvicorn.run(main.app, host=host, port=port, forwarded_allow_ips="*")
+
+
+@app.command()
+def dev(
+    host: str = "0.0.0.0",
+    port: int = 8080,
+    reload: bool = True,
+):
+    uvicorn.run(
+        "main:app", host=host, port=port, reload=reload, forwarded_allow_ips="*"
+    )
+
+
+if __name__ == "__main__":
+    app()
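
Note: with the package installed, these commands should be reachable through a console script (assuming the packaging maps an `open-webui` entry point to this Typer app):

    import subprocess

    # Equivalent to typing `open-webui serve --port 8080` in a shell;
    # the `dev` command adds --reload for local iteration.
    subprocess.run(["open-webui", "serve", "--port", "8080"], check=True)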

backend/space/litellm_config.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
+litellm_settings:
+  drop_params: true
+model_list:
+  - model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.1'
+    litellm_params:
+      model: huggingface/mistralai/Mistral-7B-Instruct-v0.1
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 1024
+  - model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.2'
+    litellm_params:
+      model: huggingface/mistralai/Mistral-7B-Instruct-v0.2
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 1024
+  - model_name: 'HuggingFace: Meta: Llama 3 8B Instruct'
+    litellm_params:
+      model: huggingface/meta-llama/Meta-Llama-3-8B-Instruct
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 2047
+  - model_name: 'HuggingFace: Mistral: Mixtral 8x7B Instruct v0.1'
+    litellm_params:
+      model: huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 8192
+  - model_name: 'HuggingFace: Microsoft: Phi-3 Mini-4K-Instruct'
+    litellm_params:
+      model: huggingface/microsoft/Phi-3-mini-4k-instruct
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 1024
+  - model_name: 'HuggingFace: Google: Gemma 7B 1.1'
+    litellm_params:
+      model: huggingface/google/gemma-1.1-7b-it
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 1024
+  - model_name: 'HuggingFace: Yi-1.5 34B Chat'
+    litellm_params:
+      model: huggingface/01-ai/Yi-1.5-34B-Chat
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 1024
+  - model_name: 'HuggingFace: Nous Research: Nous Hermes 2 Mixtral 8x7B DPO'
+    litellm_params:
+      model: huggingface/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
+      api_key: os.environ/HF_TOKEN
+      max_tokens: 2048
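
Note: each entry resolves its HuggingFace key at request time via LiteLLM's "os.environ/HF_TOKEN" convention, so only HF_TOKEN needs to be exported. A sketch of inspecting the file (PyYAML assumed, path relative to the repository root):

    import yaml  # pip install pyyaml

    with open("backend/space/litellm_config.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["litellm_settings"])  # {'drop_params': True}
    print([m["model_name"] for m in cfg["model_list"]])  # eight HF-hosted models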
@@ -30,4 +30,34 @@ if [ "$USE_CUDA_DOCKER" = "true" ]; then
     export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
 fi
 
+
+# Check if SPACE_ID is set, if so, configure for space
+if [ -n "$SPACE_ID" ]; then
+  echo "Configuring for HuggingFace Space deployment"
+
+  # Copy litellm_config.yaml with specified ownership
+  echo "Copying litellm_config.yaml to the desired location with specified ownership..."
+  cp -f ./space/litellm_config.yaml ./data/litellm/config.yaml
+
+  if [ -n "$ADMIN_USER_EMAIL" ] && [ -n "$ADMIN_USER_PASSWORD" ]; then
+    echo "Admin user configured, creating"
+    WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' &
+    webui_pid=$!
+    echo "Waiting for webui to start..."
+    while ! curl -s http://localhost:8080/health > /dev/null; do
+      sleep 1
+    done
+    echo "Creating admin user..."
+    curl \
+      -X POST "http://localhost:8080/api/v1/auths/signup" \
+      -H "accept: application/json" \
+      -H "Content-Type: application/json" \
+      -d "{ \"email\": \"${ADMIN_USER_EMAIL}\", \"password\": \"${ADMIN_USER_PASSWORD}\", \"name\": \"Admin\" }"
+    echo "Shutting down webui..."
+    kill $webui_pid
+  fi
+
+  export WEBUI_URL=${SPACE_HOST}
+fi
+
 WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'
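
Note: the one-shot signup call the script issues, restated as a Python sketch (same endpoint and payload; assumes the server is already answering on localhost:8080):

    import os
    import requests

    resp = requests.post(
        "http://localhost:8080/api/v1/auths/signup",
        headers={"accept": "application/json", "Content-Type": "application/json"},
        json={
            "email": os.environ["ADMIN_USER_EMAIL"],
            "password": os.environ["ADMIN_USER_PASSWORD"],
            "name": "Admin",
        },
    )
    resp.raise_for_status()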