Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search

This commit is contained in:
Jun Siang Cheah
2024-05-26 11:31:07 +01:00
103 changed files with 1279 additions and 1709 deletions

View File

@@ -1,388 +0,0 @@
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from apps.web.models.models import Models
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS
from constants import MESSAGES
import os
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
DATA_DIR,
LITELLM_PROXY_PORT,
LITELLM_PROXY_HOST,
)
import warnings
warnings.simplefilter("ignore")
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
@asynccontextmanager
async def lifespan(app: FastAPI):
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.MODEL_CONFIG = Models.get_all_models()
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference
background_process = None
CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]
async def run_background_process(command):
global background_process
log.info("run_background_process")
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
background_process = process
log.info("Subprocess started successfully.")
# Capture STDERR for debugging purposes
stderr_output = await process.stderr.read()
stderr_text = stderr_output.decode().strip()
if stderr_text:
log.info(f"Subprocess STDERR: {stderr_text}")
# log.info output line by line
async for line in process.stdout:
log.info(line.decode().strip())
# Wait for the process to finish
returncode = await process.wait()
log.info(f"Subprocess exited with return code {returncode}")
except Exception as e:
log.error(f"Failed to start subprocess: {e}")
raise # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = [
"litellm",
"--port",
str(LITELLM_PROXY_PORT),
"--host",
LITELLM_PROXY_HOST,
"--telemetry",
"False",
"--config",
LITELLM_CONFIG_DIR,
]
await run_background_process(command)
async def shutdown_litellm_background():
log.info("shutdown_litellm_background")
global background_process
if background_process:
background_process.terminate()
await background_process.wait() # Ensure the process has terminated
log.info("Subprocess terminated")
background_process = None
@app.get("/")
async def get_status():
return {"status": True}
async def restart_litellm():
"""
Endpoint to restart the litellm background service.
"""
log.info("Requested restart of litellm service.")
try:
# Shut down the existing process if it is running
await shutdown_litellm_background()
log.info("litellm service shutdown complete.")
# Restart the background service
asyncio.create_task(start_litellm_background())
log.info("litellm service restart complete.")
return {
"status": "success",
"message": "litellm service restarted successfully.",
}
except Exception as e:
log.info(f"Error restarting litellm service: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
return await restart_litellm()
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return app.state.CONFIG
class LiteLLMConfigForm(BaseModel):
general_settings: Optional[dict] = None
litellm_settings: Optional[dict] = None
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
app.state.CONFIG = form_data.model_dump(exclude_none=True)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
if app.state.ENABLE:
while not background_process:
await asyncio.sleep(0.1)
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
if app.state.ENABLE_MODEL_FILTER:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
"custom_info": next(
(
item
for item in app.state.MODEL_CONFIG
if item.id == model["model_name"]
),
None,
),
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
else:
return {
"data": [],
"object": "list",
}
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = f"http://localhost:{LITELLM_PROXY_PORT}"
target_url = f"{url}/{path}"
headers = {}
# headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
response_data = r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)

View File

@@ -29,8 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from apps.web.models.models import Models
from apps.web.models.users import Users
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
@@ -306,6 +306,9 @@ async def pull_model(
r = None
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
def get_request():
nonlocal url
nonlocal r
@@ -333,7 +336,7 @@ async def pull_model(
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=form_data.model_dump_json(exclude_none=True).encode(),
data=json.dumps(payload),
stream=True,
)

View File

@@ -10,8 +10,8 @@ import logging
from pydantic import BaseModel
from apps.web.models.models import Models
from apps.web.models.users import Users
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,

View File

@@ -46,7 +46,7 @@ import json
import sentence_transformers
from apps.web.models.documents import (
from apps.webui.models.documents import (
Documents,
DocumentForm,
DocumentResponse,

View File

@@ -1,144 +0,0 @@
################################################################################
# DEPRECATION NOTICE #
# #
# This file has been deprecated since version 0.2.0. #
# #
################################################################################
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
import json
####################
# Modelfile DB Schema
####################
class Modelfile(Model):
tag_name = CharField(unique=True)
user_id = CharField()
modelfile = TextField()
timestamp = BigIntegerField()
class Meta:
database = DB
class ModelfileModel(BaseModel):
tag_name: str
user_id: str
modelfile: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class ModelfileForm(BaseModel):
modelfile: dict
class ModelfileTagNameForm(BaseModel):
tag_name: str
class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm):
pass
class ModelfileResponse(BaseModel):
tag_name: str
user_id: str
modelfile: dict
timestamp: int # timestamp in epoch
class ModelfilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
"user_id": user_id,
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
}
)
try:
result = Modelfile.create(**modelfile.model_dump())
if result:
return modelfile
else:
return None
except:
return None
else:
return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),
timestamp=int(time.time()),
).where(Modelfile.tag_name == tag_name)
query.execute()
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
try:
query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Modelfiles = ModelfilesTable(DB)

View File

@@ -31,7 +31,9 @@ else:
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
router = Router(
DB, migrate_dir=BACKEND_DIR / "apps" / "web" / "internal" / "migrations", logger=log
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
DB.connect(reuse_if_open=True)

View File

@@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.

View File

@@ -1,7 +1,7 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import (
from apps.webui.routers import (
auths,
users,
chats,

View File

@@ -5,10 +5,10 @@ import uuid
import logging
from peewee import *
from apps.web.models.users import UserModel, Users
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS

View File

@@ -7,7 +7,7 @@ import json
import uuid
import time
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
####################
# Chat DB Schema
@@ -191,6 +191,20 @@ class ChatTable:
except:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
return True
except:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
@@ -205,17 +219,31 @@ class ChatTable:
]
def get_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self,
user_id: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
if include_archived:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50

View File

@@ -8,7 +8,7 @@ import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
import json

View File

@@ -3,8 +3,8 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
import time
import uuid

View File

@@ -8,7 +8,7 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField
from apps.webui.internal.db import DB, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS

View File

@@ -7,7 +7,7 @@ import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
import json

View File

@@ -8,7 +8,7 @@ import uuid
import time
import logging
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS

View File

@@ -5,8 +5,8 @@ from typing import List, Union, Optional
import time
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
####################
# User DB Schema

View File

@@ -10,7 +10,7 @@ import uuid
import csv
from apps.web.models.auths import (
from apps.webui.models.auths import (
SigninForm,
SignupForm,
AddUserForm,
@@ -21,7 +21,7 @@ from apps.web.models.auths import (
Auths,
ApiKey,
)
from apps.web.models.users import Users
from apps.webui.models.users import Users
from utils.utils import (
get_password_hash,

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel
import json
import logging
from apps.web.models.users import Users
from apps.web.models.chats import (
from apps.webui.models.users import Users
from apps.webui.models.chats import (
ChatModel,
ChatResponse,
ChatTitleForm,
@@ -18,7 +18,7 @@ from apps.web.models.chats import (
)
from apps.web.models.tags import (
from apps.webui.models.tags import (
TagModel,
ChatIdTagModel,
ChatIdTagForm,
@@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
)
############################
# GetArchivedChats
# CreateNewChat
############################
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
@@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
############################
# CreateNewChat
# GetArchivedChats
############################
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
log.exception(e)
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# ArchiveAllChats
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import Users
from apps.webui.models.users import Users
from utils.utils import (
get_password_hash,

View File

@@ -6,7 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.documents import (
from apps.webui.models.documents import (
Documents,
DocumentForm,
DocumentUpdateForm,

View File

@@ -7,7 +7,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.web.models.memories import Memories, MemoryModel
from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
from constants import ERROR_MESSAGES

View File

@@ -5,7 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
@@ -53,7 +53,7 @@ async def add_new_model(
############################
@router.get("/{id}", response_model=Optional[ModelModel])
@router.get("/", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
@@ -71,7 +71,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.post("/{id}/update", response_model=Optional[ModelModel])
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
@@ -102,7 +102,7 @@ async def update_model_by_id(
############################
@router.delete("/{id}/delete", response_model=bool)
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result

View File

@@ -6,7 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES

View File

@@ -9,9 +9,9 @@ import time
import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from apps.web.models.chats import Chats
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url

View File

@@ -27,6 +27,8 @@ from constants import ERROR_MESSAGES
BACKEND_DIR = Path(__file__).parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR)
try:
from dotenv import load_dotenv, find_dotenv
@@ -56,7 +58,6 @@ log_sources = [
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
@@ -122,7 +123,10 @@ def parse_section(section):
try:
changelog_content = (BASE_DIR / "CHANGELOG.md").read_text()
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
@@ -374,10 +378,10 @@ def create_config_file(file_path):
LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
if not os.path.exists(LITELLM_CONFIG_PATH):
log.info("Config file doesn't exist. Creating...")
create_config_file(LITELLM_CONFIG_PATH)
log.info("Config file created successfully.")
# if not os.path.exists(LITELLM_CONFIG_PATH):
# log.info("Config file doesn't exist. Creating...")
# create_config_file(LITELLM_CONFIG_PATH)
# log.info("Config file created successfully.")
####################################
@@ -845,18 +849,6 @@ AUDIO_OPENAI_API_VOICE = PersistentConfig(
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
####################################
# LiteLLM
####################################
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database

View File

@@ -22,23 +22,16 @@ from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
from apps.litellm.main import (
app as litellm_app,
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.webui.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.web.models.models import Models, ModelModel
from apps.webui.models.models import Models, ModelModel
from utils.utils import get_admin_user, get_verified_user
from apps.rag.utils import rag_messages
@@ -55,7 +48,6 @@ from config import (
STATIC_DIR,
ENABLE_OPENAI_API,
ENABLE_OLLAMA_API,
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
@@ -101,11 +93,7 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
yield
if ENABLE_LITELLM:
await shutdown_litellm_background()
app = FastAPI(
@@ -263,9 +251,6 @@ async def update_embedding_function(request: Request, call_next):
return response
# TODO: Deprecate LiteLLM
app.mount("/litellm/api", litellm_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
@@ -373,13 +358,14 @@ async def get_app_config():
"name": WEBUI_NAME,
"version": VERSION,
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"default_locale": default_locale,
"images": images_app.state.config.ENABLED,
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
"websearch": RAG_WEB_SEARCH_ENABLED,
"enable_websearch": RAG_WEB_SEARCH_ENABLED,
}
@@ -403,15 +389,6 @@ async def update_model_filter_config(
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return {
"enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.config.MODEL_FILTER_LIST,
@@ -432,7 +409,6 @@ class UrlForm(BaseModel):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {

View File

@@ -18,8 +18,6 @@ psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.1.3
litellm[proxy]==1.37.20
boto3==1.34.110
argon2-cffi==23.1.0

View File

@@ -1,43 +0,0 @@
litellm_settings:
drop_params: true
model_list:
- model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.1'
litellm_params:
model: huggingface/mistralai/Mistral-7B-Instruct-v0.1
api_key: os.environ/HF_TOKEN
max_tokens: 1024
- model_name: 'HuggingFace: Mistral: Mistral 7B Instruct v0.2'
litellm_params:
model: huggingface/mistralai/Mistral-7B-Instruct-v0.2
api_key: os.environ/HF_TOKEN
max_tokens: 1024
- model_name: 'HuggingFace: Meta: Llama 3 8B Instruct'
litellm_params:
model: huggingface/meta-llama/Meta-Llama-3-8B-Instruct
api_key: os.environ/HF_TOKEN
max_tokens: 2047
- model_name: 'HuggingFace: Mistral: Mixtral 8x7B Instruct v0.1'
litellm_params:
model: huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1
api_key: os.environ/HF_TOKEN
max_tokens: 8192
- model_name: 'HuggingFace: Microsoft: Phi-3 Mini-4K-Instruct'
litellm_params:
model: huggingface/microsoft/Phi-3-mini-4k-instruct
api_key: os.environ/HF_TOKEN
max_tokens: 1024
- model_name: 'HuggingFace: Google: Gemma 7B 1.1'
litellm_params:
model: huggingface/google/gemma-1.1-7b-it
api_key: os.environ/HF_TOKEN
max_tokens: 1024
- model_name: 'HuggingFace: Yi-1.5 34B Chat'
litellm_params:
model: huggingface/01-ai/Yi-1.5-34B-Chat
api_key: os.environ/HF_TOKEN
max_tokens: 1024
- model_name: 'HuggingFace: Nous Research: Nous Hermes 2 Mixtral 8x7B DPO'
litellm_params:
model: huggingface/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
api_key: os.environ/HF_TOKEN
max_tokens: 2048

View File

@@ -34,11 +34,6 @@ fi
# Check if SPACE_ID is set, if so, configure for space
if [ -n "$SPACE_ID" ]; then
echo "Configuring for HuggingFace Space deployment"
# Copy litellm_config.yaml with specified ownership
echo "Copying litellm_config.yaml to the desired location with specified ownership..."
cp -f ./space/litellm_config.yaml ./data/litellm/config.yaml
if [ -n "$ADMIN_USER_EMAIL" ] && [ -n "$ADMIN_USER_PASSWORD" ]; then
echo "Admin user configured, creating"
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*' &

View File

@@ -1,4 +1,4 @@
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):

View File

@@ -1,7 +1,7 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from apps.webui.models.users import Users
from pydantic import BaseModel
from typing import Union, Optional