mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into dev
This commit is contained in:
@@ -709,8 +709,8 @@ def save_docs_to_vector_db(
|
||||
if overwrite:
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
|
||||
log.info(f"deleting existing collection {collection_name}")
|
||||
|
||||
if add is False:
|
||||
elif add is False:
|
||||
log.info(f"collection {collection_name} already exists, overwrite is False and add is False")
|
||||
return True
|
||||
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
|
||||
@@ -385,6 +385,8 @@ def get_rag_context(
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
if context:
|
||||
if "data" in file:
|
||||
del file["data"]
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
contexts = []
|
||||
@@ -401,11 +403,8 @@ def get_rag_context(
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
contexts.append(
|
||||
(", ".join(file_names) + ":\n\n")
|
||||
if file_names
|
||||
else ""
|
||||
((", ".join(file_names) + ":\n\n") if file_names else "")
|
||||
+ "\n\n".join(
|
||||
[text for text in context["documents"][0] if text is not None]
|
||||
)
|
||||
@@ -423,7 +422,9 @@ def get_rag_context(
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
print(contexts, citations)
|
||||
print("contexts", contexts)
|
||||
print("citations", citations)
|
||||
|
||||
return contexts, citations
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.apps.webui.routers import (
|
||||
auths,
|
||||
chats,
|
||||
folders,
|
||||
configs,
|
||||
files,
|
||||
functions,
|
||||
@@ -110,6 +111,7 @@ app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
app.include_router(folders.router, prefix="/folders", tags=["folders"])
|
||||
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
@@ -33,6 +33,7 @@ class Chat(Base):
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default="{}")
|
||||
folder_id = Column(Text, nullable=True)
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
@@ -51,6 +52,7 @@ class ChatModel(BaseModel):
|
||||
pinned: Optional[bool] = False
|
||||
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
####################
|
||||
@@ -61,10 +63,12 @@ class ChatModel(BaseModel):
|
||||
class ChatForm(BaseModel):
|
||||
chat: dict
|
||||
|
||||
|
||||
class ChatTitleMessagesForm(BaseModel):
|
||||
title: str
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class ChatTitleForm(BaseModel):
|
||||
title: str
|
||||
|
||||
@@ -80,6 +84,7 @@ class ChatResponse(BaseModel):
|
||||
archived: bool
|
||||
pinned: Optional[bool] = False
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
@@ -252,14 +257,18 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
all_chats = (
|
||||
query.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit).offset(skip)
|
||||
.all()
|
||||
)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_title_id_list_by_user_id(
|
||||
@@ -270,7 +279,9 @@ class ChatTable:
|
||||
limit: Optional[int] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
@@ -361,7 +372,7 @@ class ChatTable:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, pinned=True)
|
||||
.filter_by(user_id=user_id, pinned=True, archived=False)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
@@ -387,9 +398,25 @@ class ChatTable:
|
||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||
"""
|
||||
search_text = search_text.lower().strip()
|
||||
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
|
||||
|
||||
search_text_words = search_text.split(" ")
|
||||
|
||||
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
|
||||
tag_ids = [
|
||||
word.replace("tag:", "").replace(" ", "_").lower()
|
||||
for word in search_text_words
|
||||
if word.startswith("tag:")
|
||||
]
|
||||
|
||||
search_text_words = [
|
||||
word for word in search_text_words if not word.startswith("tag:")
|
||||
]
|
||||
|
||||
search_text = " ".join(search_text_words)
|
||||
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
@@ -418,6 +445,26 @@ class ChatTable:
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
WHERE tag.value = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
elif dialect_name == "postgresql":
|
||||
# PostgreSQL relies on proper JSON query for search
|
||||
query = query.filter(
|
||||
@@ -436,6 +483,25 @@ class ChatTable:
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
WHERE tag = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
@@ -444,9 +510,34 @@ class ChatTable:
|
||||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
print(len(all_chats))
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id).all()
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def update_chat_folder_id_by_id_and_user_id(
|
||||
self, id: str, user_id: str, folder_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.folder_id = folder_id
|
||||
chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
@@ -498,7 +589,7 @@ class ChatTable:
|
||||
if tag_id not in chat.meta.get("tags", []):
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": chat.meta.get("tags", []) + [tag_id],
|
||||
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
|
||||
}
|
||||
|
||||
db.commit()
|
||||
@@ -509,7 +600,7 @@ class ChatTable:
|
||||
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
|
||||
with get_db() as db: # Assuming `get_db()` returns a session object
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
@@ -555,7 +646,7 @@ class ChatTable:
|
||||
tags = [tag for tag in tags if tag != tag_id]
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": tags,
|
||||
"tags": list(set(tags)),
|
||||
}
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
225
backend/open_webui/apps/webui/models/folders.py
Normal file
225
backend/open_webui/apps/webui/models/folders.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Folder DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Folder(Base):
|
||||
__tablename__ = "folder"
|
||||
id = Column(Text, primary_key=True)
|
||||
parent_id = Column(Text, nullable=True)
|
||||
user_id = Column(Text)
|
||||
name = Column(Text)
|
||||
items = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
is_expanded = Column(Boolean, default=False)
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FolderModel(BaseModel):
|
||||
id: str
|
||||
parent_id: Optional[str] = None
|
||||
user_id: str
|
||||
name: str
|
||||
items: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
is_expanded: bool = False
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FolderForm(BaseModel):
|
||||
name: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FolderTable:
|
||||
def insert_new_folder(
|
||||
self, user_id: str, name: str, parent_id: Optional[str] = None
|
||||
) -> Optional[FolderModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
folder = FolderModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"parent_id": parent_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
result = Folder(**folder.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FolderModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def get_folder_by_parent_id_and_user_id_and_name(
|
||||
self, parent_id: Optional[str], user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Check if folder exists
|
||||
folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.filter(Folder.name.ilike(name))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
|
||||
return None
|
||||
|
||||
def get_folders_by_parent_id_and_user_id(
|
||||
self, parent_id: Optional[str], user_id: str
|
||||
) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.all()
|
||||
]
|
||||
|
||||
def update_folder_parent_id_by_id_and_user_id(
|
||||
self,
|
||||
id: str,
|
||||
user_id: str,
|
||||
parent_id: str,
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.parent_id = parent_id
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def update_folder_name_by_id_and_user_id(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
existing_folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
return None
|
||||
|
||||
folder.name = name
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def update_folder_is_expanded_by_id_and_user_id(
|
||||
self, id: str, user_id: str, is_expanded: bool
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.is_expanded = is_expanded
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_folder: {e}")
|
||||
return False
|
||||
|
||||
|
||||
Folders = FolderTable()
|
||||
@@ -1,10 +1,13 @@
|
||||
import re
|
||||
import uuid
|
||||
import time
|
||||
import datetime
|
||||
|
||||
from open_webui.apps.webui.models.auths import (
|
||||
AddUserForm,
|
||||
ApiKey,
|
||||
Auths,
|
||||
Token,
|
||||
SigninForm,
|
||||
SigninResponse,
|
||||
SignupForm,
|
||||
@@ -34,6 +37,7 @@ from open_webui.utils.utils import (
|
||||
get_password_hash,
|
||||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -42,25 +46,44 @@ router = APIRouter()
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=UserResponse)
|
||||
class SessionUserResponse(Token, UserResponse):
|
||||
expires_at: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/", response_model=SessionUserResponse)
|
||||
async def get_session_user(
|
||||
request: Request, response: Response, user=Depends(get_current_user)
|
||||
):
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
expires_at = None
|
||||
if expires_delta:
|
||||
expires_at = int(time.time()) + int(expires_delta.total_seconds())
|
||||
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
expires_delta=expires_delta,
|
||||
)
|
||||
|
||||
datetime_expires_at = (
|
||||
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
|
||||
if expires_at
|
||||
else None
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_at": expires_at,
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
@@ -119,7 +142,7 @@ async def update_password(
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/signin", response_model=SigninResponse)
|
||||
@router.post("/signin", response_model=SessionUserResponse)
|
||||
async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
||||
@@ -161,23 +184,37 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
|
||||
|
||||
if user:
|
||||
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
expires_at = None
|
||||
if expires_delta:
|
||||
expires_at = int(time.time()) + int(expires_delta.total_seconds())
|
||||
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
expires_delta=expires_delta,
|
||||
)
|
||||
|
||||
datetime_expires_at = (
|
||||
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
|
||||
if expires_at
|
||||
else None
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_at": expires_at,
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
@@ -193,7 +230,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/signup", response_model=SigninResponse)
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
@@ -233,18 +270,30 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
)
|
||||
|
||||
if user:
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
expires_at = None
|
||||
if expires_delta:
|
||||
expires_at = int(time.time()) + int(expires_delta.total_seconds())
|
||||
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
expires_delta=expires_delta,
|
||||
)
|
||||
|
||||
datetime_expires_at = (
|
||||
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
|
||||
if expires_at
|
||||
else None
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
expires=datetime_expires_at,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
@@ -261,6 +310,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"expires_at": expires_at,
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
|
||||
@@ -114,13 +114,24 @@ async def search_user_chats(
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
return [
|
||||
chat_list = [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id_and_search_text(
|
||||
user.id, text, skip=skip, limit=limit
|
||||
)
|
||||
]
|
||||
|
||||
# Delete tag if no chat is found
|
||||
words = text.strip().split(" ")
|
||||
if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
|
||||
tag_id = words[0].replace("tag:", "")
|
||||
if len(chat_list) == 0:
|
||||
if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
|
||||
return chat_list
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedChats
|
||||
@@ -315,7 +326,13 @@ async def update_chat_by_id(
|
||||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
|
||||
return result
|
||||
else:
|
||||
if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
|
||||
@@ -326,6 +343,11 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id(id)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
|
||||
return result
|
||||
|
||||
@@ -397,6 +419,20 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
|
||||
# Delete tags if chat is archived
|
||||
if chat.archived:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
else:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
||||
if tag is None:
|
||||
log.debug(f"inserting tag: {tag_id}")
|
||||
tag = Tags.insert_new_tag(tag_id, user.id)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -455,6 +491,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChatFolderIdById
|
||||
############################
|
||||
|
||||
|
||||
class ChatFolderIdForm(BaseModel):
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
|
||||
async def update_chat_folder_id_by_id(
|
||||
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.update_chat_folder_id_by_id_and_user_id(
|
||||
id, user.id, form_data.folder_id
|
||||
)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatTagsById
|
||||
############################
|
||||
|
||||
259
backend/open_webui/apps/webui/routers/folders.py
Normal file
259
backend/open_webui/apps/webui/routers/folders.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import mimetypes
|
||||
|
||||
|
||||
from open_webui.apps.webui.models.folders import (
|
||||
FolderForm,
|
||||
FolderModel,
|
||||
Folders,
|
||||
)
|
||||
from open_webui.apps.webui.models.chats import Chats
|
||||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# Get Folders
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[FolderModel])
|
||||
async def get_folders(user=Depends(get_verified_user)):
|
||||
folders = Folders.get_folders_by_user_id(user.id)
|
||||
|
||||
return [
|
||||
{
|
||||
**folder.model_dump(),
|
||||
"items": {
|
||||
"chats": [
|
||||
{"title": chat.title, "id": chat.id}
|
||||
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
||||
folder.id, user.id
|
||||
)
|
||||
]
|
||||
},
|
||||
}
|
||||
for folder in folders
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# Create Folder
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/")
|
||||
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
None, user.id, form_data.name
|
||||
)
|
||||
|
||||
if folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.insert_new_folder(user.id, form_data.name)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error("Error creating folder")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating folder"),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get Folders By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[FolderModel])
|
||||
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
return folder
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Name By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update")
|
||||
async def update_folder_name_by_id(
|
||||
id: str, form_data: FolderForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
folder.parent_id, user.id, form_data.name
|
||||
)
|
||||
if existing_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.update_folder_name_by_id_and_user_id(
|
||||
id, user.id, form_data.name
|
||||
)
|
||||
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Parent Id By Id
|
||||
############################
|
||||
|
||||
|
||||
class FolderParentIdForm(BaseModel):
|
||||
parent_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/update/parent")
|
||||
async def update_folder_parent_id_by_id(
|
||||
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
form_data.parent_id, user.id, folder.name
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||
id, user.id, form_data.parent_id
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Is Expanded By Id
|
||||
############################
|
||||
|
||||
|
||||
class FolderIsExpandedForm(BaseModel):
|
||||
is_expanded: bool
|
||||
|
||||
|
||||
@router.post("/{id}/update/expanded")
|
||||
async def update_folder_is_expanded_by_id(
|
||||
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
|
||||
id, user.id, form_data.is_expanded
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Delete Folder By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
result = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||
if result:
|
||||
# Delete all chats in the folder
|
||||
chats = Chats.get_chats_by_folder_id_and_user_id(id, user.id)
|
||||
for chat in chats:
|
||||
Chats.delete_chat_by_id(chat.id, user.id)
|
||||
|
||||
return result
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
@@ -1042,7 +1042,7 @@ CHUNK_OVERLAP = PersistentConfig(
|
||||
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
|
||||
|
||||
<context>
|
||||
[context]
|
||||
{{CONTEXT}}
|
||||
</context>
|
||||
|
||||
<rules>
|
||||
@@ -1055,7 +1055,7 @@ DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and r
|
||||
</rules>
|
||||
|
||||
<user_query>
|
||||
[query]
|
||||
{{QUERY}}
|
||||
</user_query>
|
||||
"""
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ class ERROR_MESSAGES(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
||||
DEFAULT = lambda err="": f"Something went wrong :/\n[ERROR: {err if err else ''}]"
|
||||
DEFAULT = (
|
||||
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + err + "]"}'
|
||||
)
|
||||
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
|
||||
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
|
||||
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
|
||||
|
||||
@@ -135,7 +135,7 @@ def upgrade():
|
||||
tags = chat_updates[chat_id]["meta"].get("tags", [])
|
||||
tags.append(tag_name)
|
||||
|
||||
chat_updates[chat_id]["meta"]["tags"] = tags
|
||||
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
|
||||
|
||||
# Update chats based on accumulated changes
|
||||
for chat_id, updates in chat_updates.items():
|
||||
|
||||
@@ -28,25 +28,39 @@ def upgrade():
|
||||
unique_constraints = inspector.get_unique_constraints("tag")
|
||||
existing_indexes = inspector.get_indexes("tag")
|
||||
|
||||
print(existing_pk, unique_constraints)
|
||||
print(f"Primary Key: {existing_pk}")
|
||||
print(f"Unique Constraints: {unique_constraints}")
|
||||
print(f"Indexes: {existing_indexes}")
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Drop unique constraints that could conflict with new primary key
|
||||
# Drop existing primary key constraint if it exists
|
||||
if existing_pk and existing_pk.get("constrained_columns"):
|
||||
pk_name = existing_pk.get("name")
|
||||
if pk_name:
|
||||
print(f"Dropping primary key constraint: {pk_name}")
|
||||
batch_op.drop_constraint(pk_name, type_="primary")
|
||||
|
||||
# Now create the new primary key with the combination of 'id' and 'user_id'
|
||||
print("Creating new primary key with 'id' and 'user_id'.")
|
||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
||||
|
||||
# Drop unique constraints that could conflict with the new primary key
|
||||
for constraint in unique_constraints:
|
||||
if constraint["name"] == "uq_id_user_id":
|
||||
if (
|
||||
constraint["name"] == "uq_id_user_id"
|
||||
): # Adjust this name according to what is actually returned by the inspector
|
||||
print(f"Dropping unique constraint: {constraint['name']}")
|
||||
batch_op.drop_constraint(constraint["name"], type_="unique")
|
||||
|
||||
for index in existing_indexes:
|
||||
if index["unique"]:
|
||||
# Drop the unique index
|
||||
batch_op.drop_index(index["name"])
|
||||
|
||||
# Drop existing primary key constraint if it exists
|
||||
if existing_pk and existing_pk.get("constrained_columns"):
|
||||
batch_op.drop_constraint(existing_pk["name"], type_="primary")
|
||||
|
||||
# Immediately after dropping the old primary key, create the new one
|
||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
||||
if not any(
|
||||
constraint["name"] == index["name"]
|
||||
for constraint in unique_constraints
|
||||
):
|
||||
# You are attempting to drop unique indexes
|
||||
print(f"Dropping unique index: {index['name']}")
|
||||
batch_op.drop_index(index["name"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Add folder table
|
||||
|
||||
Revision ID: c69f45358db4
|
||||
Revises: 3ab32c4b8f59
|
||||
Create Date: 2024-10-16 02:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c69f45358db4"
|
||||
down_revision = "3ab32c4b8f59"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"folder",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("items", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat",
|
||||
sa.Column("folder_id", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("chat", "folder_id")
|
||||
|
||||
op.drop_table("folder")
|
||||
@@ -7,7 +7,7 @@ import jwt
|
||||
from open_webui.apps.webui.models.users import Users
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SECRET_KEY
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi import Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
@@ -90,3 +90,5 @@ duckduckgo-search~=6.2.13
|
||||
docker~=7.1.0
|
||||
pytest~=8.3.2
|
||||
pytest-docker~=3.1.1
|
||||
|
||||
googleapis-common-protos==1.63.2
|
||||
|
||||
Reference in New Issue
Block a user