mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'refs/heads/dev' into feat/sqlalchemy-instead-of-peewee
# Conflicts: # backend/apps/webui/models/functions.py # backend/apps/webui/routers/chats.py
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
"""Peewee migrations -- 017_add_user_oauth_sub.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"function",
|
||||
is_global=pw.BooleanField(default=False),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("function", "is_global")
|
||||
@@ -33,6 +33,7 @@ class Function(Base):
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
is_active = Column(Boolean)
|
||||
is_global = Column(Boolean)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
@@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool = False
|
||||
is_global: bool = False
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@@ -68,6 +70,7 @@ class FunctionResponse(BaseModel):
|
||||
name: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool
|
||||
is_global: bool
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@@ -146,6 +149,16 @@ class FunctionsTable:
|
||||
for function in Session.query(Function).filter_by(type=type).all()
|
||||
]
|
||||
|
||||
def get_global_filter_functions(self) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(
|
||||
Function.type == "filter",
|
||||
Function.is_active == True,
|
||||
Function.is_global == True,
|
||||
)
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
function = Session.get(Function, id)
|
||||
|
||||
@@ -136,6 +136,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
if not Users.get_user_by_email(trusted_email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
|
||||
),
|
||||
@@ -153,6 +154,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(email=admin_email, password=admin_password, name="User"),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from fastapi import Depends, Request, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
@@ -44,7 +43,7 @@ router = APIRouter()
|
||||
@router.get("/", response_model=List[ChatTitleIdResponse])
|
||||
@router.get("/list", response_model=List[ChatTitleIdResponse])
|
||||
async def get_session_user_chat_list(
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
|
||||
|
||||
@@ -55,7 +54,7 @@ async def get_session_user_chat_list(
|
||||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
@@ -93,7 +92,7 @@ async def get_user_chat_list_by_user_id(
|
||||
|
||||
|
||||
@router.post("/new", response_model=Optional[ChatResponse])
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
@@ -110,7 +109,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/all", response_model=List[ChatResponse])
|
||||
async def get_user_chats(user=Depends(get_current_user)):
|
||||
async def get_user_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_chats_by_user_id(user.id)
|
||||
@@ -123,7 +122,7 @@ async def get_user_chats(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/all/archived", response_model=List[ChatResponse])
|
||||
async def get_user_archived_chats(user=Depends(get_current_user)):
|
||||
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
||||
@@ -155,7 +154,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
|
||||
@router.get("/archived", response_model=List[ChatTitleIdResponse])
|
||||
async def get_archived_session_user_chat_list(
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
|
||||
|
||||
@@ -166,7 +165,7 @@ async def get_archived_session_user_chat_list(
|
||||
|
||||
|
||||
@router.post("/archive/all", response_model=bool)
|
||||
async def archive_all_chats(user=Depends(get_current_user)):
|
||||
async def archive_all_chats(user=Depends(get_verified_user)):
|
||||
return Chats.archive_all_chats_by_user_id(user.id)
|
||||
|
||||
|
||||
@@ -176,7 +175,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -208,7 +207,7 @@ class TagNameForm(BaseModel):
|
||||
|
||||
@router.post("/tags", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_tag_name(
|
||||
form_data: TagNameForm, user=Depends(get_current_user)
|
||||
form_data: TagNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
print(form_data)
|
||||
@@ -233,7 +232,7 @@ async def get_user_chat_list_by_tag_name(
|
||||
|
||||
|
||||
@router.get("/tags/all", response_model=List[TagModel])
|
||||
async def get_all_tags(user=Depends(get_current_user)):
|
||||
async def get_all_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
@@ -250,7 +249,7 @@ async def get_all_tags(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[ChatResponse])
|
||||
async def get_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
|
||||
if chat:
|
||||
@@ -268,7 +267,7 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
@router.post("/{id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_by_id(
|
||||
id: str, form_data: ChatForm, user=Depends(get_current_user)
|
||||
id: str, form_data: ChatForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
@@ -289,7 +288,7 @@ async def update_chat_by_id(
|
||||
|
||||
|
||||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
|
||||
if user.role == "admin":
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
@@ -311,7 +310,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
|
||||
|
||||
|
||||
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
|
||||
@@ -337,7 +336,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
@@ -354,7 +353,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
||||
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
if chat.share_id:
|
||||
@@ -386,7 +385,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
if not chat.share_id:
|
||||
@@ -409,7 +408,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/{id}/tags", response_model=List[TagModel])
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if tags != None:
|
||||
@@ -427,7 +426,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
|
||||
async def add_chat_tag_by_id(
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
@@ -454,9 +453,7 @@ async def add_chat_tag_by_id(
|
||||
|
||||
@router.delete("/{id}/tags", response_model=Optional[bool])
|
||||
async def delete_chat_tag_by_id(
|
||||
id: str,
|
||||
form_data: ChatIdTagForm,
|
||||
user=Depends(get_current_user),
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
form_data.tag_name, id, user.id
|
||||
@@ -476,7 +473,7 @@ async def delete_chat_tag_by_id(
|
||||
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -14,7 +14,7 @@ from apps.webui.models.users import Users
|
||||
|
||||
from utils.utils import (
|
||||
get_password_hash,
|
||||
get_current_user,
|
||||
get_verified_user,
|
||||
get_admin_user,
|
||||
create_token,
|
||||
)
|
||||
@@ -84,6 +84,6 @@ async def set_banners(
|
||||
@router.get("/banners", response_model=List[BannerModel])
|
||||
async def get_banners(
|
||||
request: Request,
|
||||
user=Depends(get_current_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
return request.app.state.config.BANNERS
|
||||
|
||||
@@ -14,7 +14,7 @@ from apps.webui.models.documents import (
|
||||
DocumentResponse,
|
||||
)
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
@@ -25,7 +25,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[DocumentResponse])
|
||||
async def get_documents(user=Depends(get_current_user)):
|
||||
async def get_documents(user=Depends(get_verified_user)):
|
||||
docs = [
|
||||
DocumentResponse(
|
||||
**{
|
||||
@@ -74,7 +74,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
|
||||
|
||||
|
||||
@router.get("/doc", response_model=Optional[DocumentResponse])
|
||||
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
|
||||
async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
|
||||
doc = Documents.get_doc_by_name(name)
|
||||
|
||||
if doc:
|
||||
@@ -106,7 +106,7 @@ class TagDocumentForm(BaseModel):
|
||||
|
||||
|
||||
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
|
||||
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
|
||||
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
|
||||
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
|
||||
|
||||
if doc:
|
||||
|
||||
@@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ToggleGlobalById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
id, {"is_global": not function.is_global}
|
||||
)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateFunctionById
|
||||
############################
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
|
||||
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
@@ -19,7 +19,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[PromptModel])
|
||||
async def get_prompts(user=Depends(get_current_user)):
|
||||
async def get_prompts(user=Depends(get_verified_user)):
|
||||
return Prompts.get_prompts()
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
|
||||
|
||||
|
||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
|
||||
if prompt:
|
||||
|
||||
Reference in New Issue
Block a user