Merge branch 'websearch' into feat/backend-web-search

This commit is contained in:
Timothy Jaeryang Baek
2024-05-26 23:40:05 -07:00
committed by GitHub
77 changed files with 1338 additions and 345 deletions

View File

@@ -198,7 +198,7 @@ async def fetch_url(url, key):
def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
@@ -237,7 +237,7 @@ async def get_all_models():
]
responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
log.debug(f"get_all_models:responses() {responses}")
models = {
"data": merge_models_lists(
@@ -254,7 +254,7 @@ async def get_all_models():
)
}
log.info(f"models: {models}")
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")

View File

@@ -13,7 +13,7 @@ from apps.webui.routers import (
utils,
)
from config import (
WEBUI_VERSION,
WEBUI_BUILD_HASH,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
@@ -23,7 +23,9 @@ from config import (
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
)
app = FastAPI()
@@ -40,7 +42,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

View File

@@ -1,11 +1,11 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB
from apps.webui.internal.db import DB, JSONField
from apps.webui.models.chats import Chats
####################
@@ -25,11 +25,18 @@ class User(Model):
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
class Meta:
database = DB
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class UserModel(BaseModel):
id: str
name: str
@@ -42,6 +49,7 @@ class UserModel(BaseModel):
created_at: int # timestamp in epoch
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
####################

View File

@@ -8,6 +8,8 @@ from pydantic import BaseModel
import time
import uuid
from config import BannerModel
from apps.webui.models.users import Users
from utils.utils import (
@@ -57,3 +59,31 @@ async def set_global_default_suggestions(
data = form_data.model_dump()
request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
############################
# SetBanners
############################
class SetBannersForm(BaseModel):
banners: List[BannerModel]
@router.post("/banners", response_model=List[BannerModel])
async def set_banners(
request: Request,
form_data: SetBannersForm,
user=Depends(get_admin_user),
):
data = form_data.model_dump()
request.app.state.config.BANNERS = data["banners"]
return request.app.state.config.BANNERS
@router.get("/banners", response_model=List[BannerModel])
async def get_banners(
request: Request,
user=Depends(get_current_user),
):
return request.app.state.config.BANNERS

View File

@@ -9,7 +9,13 @@ import time
import uuid
import logging
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
UserRoleUpdateForm,
UserSettings,
Users,
)
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
@@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
)
############################
# GetUserSettingsBySessionUser
############################
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserSettingsBySessionUser
############################
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################
@@ -81,6 +123,8 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id)

View File

@@ -8,6 +8,8 @@ from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from pydantic import BaseModel
from typing import Optional
from pathlib import Path
import json
@@ -166,10 +168,10 @@ CHANGELOG = changelog_json
####################################
# WEBUI_VERSION
# WEBUI_BUILD_HASH
####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
@@ -566,6 +568,27 @@ WEBHOOK_URL = PersistentConfig(
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
ENABLE_COMMUNITY_SHARING = PersistentConfig(
"ENABLE_COMMUNITY_SHARING",
"ui.enable_community_sharing",
os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)
class BannerModel(BaseModel):
id: str
type: str
title: Optional[str] = None
content: str
dismissible: bool
timestamp: int
WEBUI_BANNERS = PersistentConfig(
"WEBUI_BANNERS",
"ui.banners",
[BannerModel(**banner) for banner in json.loads("[]")],
)
####################################
# WEBUI_SECRET_KEY
####################################

View File

@@ -56,6 +56,7 @@ from config import (
ENABLE_ADMIN_EXPORT,
RAG_WEB_SEARCH_ENABLED,
AppConfig,
WEBUI_BUILD_HASH,
)
from constants import ERROR_MESSAGES
@@ -85,7 +86,8 @@ print(
|_|
v{VERSION} - building the best open-source AI user interface.
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@@ -357,15 +359,18 @@ async def get_app_config():
"status": True,
"name": WEBUI_NAME,
"version": VERSION,
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"default_locale": default_locale,
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"enable_websearch": RAG_WEB_SEARCH_ENABLED,
"features": {
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_websearch": RAG_WEB_SEARCH_ENABLED,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
},
}
@@ -416,6 +421,19 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
}
@app.get("/api/community_sharing", response_model=bool)
async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)):
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
@app.get("/api/community_sharing/toggle", response_model=bool)
async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)):
webui_app.state.config.ENABLE_COMMUNITY_SHARING = (
not webui_app.state.config.ENABLE_COMMUNITY_SHARING
)
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
@app.get("/api/version")
async def get_app_config():
return {