Merge branch 'upstream-dev' into dev

This commit is contained in:
Jannik Streidl
2024-10-12 15:18:59 +02:00
103 changed files with 2956 additions and 3023 deletions

View File

@@ -450,7 +450,7 @@ def transcribe(file_path):
except Exception:
error_detail = f"External: {e}"
raise error_detail
raise Exception(error_detail)
@app.post("/transcriptions")

View File

@@ -125,22 +125,34 @@ async def comfyui_generate_image(
workflow[node_id]["inputs"][node.key] = model
elif node.type == "prompt":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.prompt
workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.prompt
elif node.type == "negative_prompt":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["text"] = payload.negative_prompt
workflow[node_id]["inputs"][
node.key if node.key else "text"
] = payload.negative_prompt
elif node.type == "width":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["width"] = payload.width
workflow[node_id]["inputs"][
node.key if node.key else "width"
] = payload.width
elif node.type == "height":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["height"] = payload.height
workflow[node_id]["inputs"][
node.key if node.key else "height"
] = payload.height
elif node.type == "n":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["batch_size"] = payload.n
workflow[node_id]["inputs"][
node.key if node.key else "batch_size"
] = payload.n
elif node.type == "steps":
for node_id in node.node_ids:
workflow[node_id]["inputs"]["steps"] = payload.steps
workflow[node_id]["inputs"][
node.key if node.key else "steps"
] = payload.steps
elif node.type == "seed":
seed = (
payload.seed

View File

@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel):
model: str
input: str
truncate: Optional[bool]
input: list[str]|str
truncate: Optional[bool] = None
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None
@@ -560,48 +560,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
return data["embedding"]
return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
def generate_ollama_batch_embeddings(
form_data: GenerateEmbedForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
data = r.json()
log.info(f"generate_ollama_batch_embeddings {data}")
if "embeddings" in data:
return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
@@ -788,8 +763,7 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
):
payload = {**form_data.model_dump(exclude_none=True)}
log.debug(f"{payload = }")
log.debug(f"generate_chat_completion() - 1.payload = {payload}")
if "metadata" in payload:
del payload["metadata"]
@@ -824,7 +798,7 @@ async def generate_chat_completion(
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
log.debug(payload)
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
return await post_streaming_url(
f"{url}/api/chat",

View File

@@ -63,7 +63,7 @@ from open_webui.config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
}
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -334,17 +330,17 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -645,7 +641,7 @@ def save_docs_to_vector_db(
filter={"hash": metadata["hash"]},
)
if result:
if result is not None:
existing_doc_ids = result.ids[0]
if existing_doc_ids:
log.info(f"Document with hash {metadata['hash']} already exists")
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
embeddings = embedding_function(
@@ -767,7 +763,7 @@ def process_file(
collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
if len(result.ids[0]) > 0:
if result is not None and len(result.ids[0]) > 0:
docs = [
Document(
page_content=result.documents[0][idx],

View File

@@ -12,8 +12,8 @@ from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@@ -193,7 +193,8 @@ def query_collection(
k=k,
query_embedding=query_embedding,
)
results.append(result.model_dump())
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
@@ -265,39 +266,27 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
batch_size,
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
func = lambda query: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
)
def generate_multiple(query, f):
def generate_multiple(query, func):
if isinstance(query, list):
if embedding_engine == "openai":
embeddings = []
for i in range(0, len(query), batch_size):
embeddings.extend(f(query[i : i + batch_size]))
return embeddings
else:
return [f(q) for q in query]
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings
else:
return f(query)
return func(query)
return lambda query: generate_multiple(query, func)
@@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False):
return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
@@ -482,6 +457,33 @@ def generate_openai_batch_embeddings(
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
    """Embed ``text`` with the selected remote engine.

    Args:
        engine: ``"ollama"`` or ``"openai"``; any other value falls through
            and the function returns ``None``.
        model: Name of the embedding model passed to the engine.
        text: A single input or a list of inputs.
        **kwargs: ``key`` and ``url`` are read for the ``"openai"`` engine
            (defaults: ``""`` and ``"https://api.openai.com/v1"``).

    Returns:
        The first embedding when ``text`` is a ``str``; otherwise whatever
        the batch helper produced for the whole batch.
    """
    # The batch helpers are always called with a list, so a lone input is
    # wrapped in a one-element batch and unwrapped again on the way out.
    batch = text if isinstance(text, list) else [text]
    wants_single = isinstance(text, str)

    if engine == "ollama":
        response = generate_ollama_batch_embeddings(
            GenerateEmbedForm(**{"model": model, "input": batch})
        )
        return response["embeddings"][0] if wants_single else response["embeddings"]

    if engine == "openai":
        api_key = kwargs.get("key", "")
        api_url = kwargs.get("url", "https://api.openai.com/v1")
        vectors = generate_openai_batch_embeddings(model, batch, api_key, api_url)
        return vectors[0] if wants_single else vectors
import operator
from typing import Optional, Sequence

View File

@@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient

View File

@@ -0,0 +1,176 @@
from typing import Optional
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
    """Vector-store adapter backed by Qdrant.

    A caller-facing ``collection_name`` is stored in Qdrant under
    ``"{collection_prefix}_{collection_name}"``. Every point carries its text
    and metadata in the payload under the keys ``"text"`` and ``"metadata"``.
    """

    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        # No client is built when QDRANT_URI is empty; the methods below
        # dereference self.client unconditionally and fail in that setup.
        self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None

    def _prefixed(self, collection_name: str) -> str:
        """Return the name under which ``collection_name`` is stored in Qdrant."""
        return f"{self.collection_prefix}_{collection_name}"

    def _result_to_get_result(self, points) -> GetResult:
        """Pack Qdrant points into a ``GetResult`` holding a single result row."""
        ids = []
        documents = []
        metadatas = []

        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        # Each field is wrapped in an outer list: one row per query.
        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        """Create the prefixed collection with cosine distance and ``dimension``-sized vectors."""
        collection_name_with_prefix = self._prefixed(collection_name)
        self.client.create_collection(
            collection_name=collection_name_with_prefix,
            vectors_config=models.VectorParams(
                size=dimension, distance=models.Distance.COSINE
            ),
        )
        print(f"collection {collection_name_with_prefix} successfully created!")

    def _create_collection_if_not_exists(self, collection_name, dimension):
        """Create the collection unless ``has_collection`` already reports it."""
        if not self.has_collection(collection_name=collection_name):
            self._create_collection(
                collection_name=collection_name, dimension=dimension
            )

    def _create_points(self, items: list[VectorItem]):
        """Convert vector items into Qdrant ``PointStruct`` objects."""
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={"text": item["text"], "metadata": item["metadata"]},
            )
            for item in items
        ]

    def has_collection(self, collection_name: str) -> bool:
        """Return whether the prefixed collection exists."""
        return self.client.collection_exists(self._prefixed(collection_name))

    def delete_collection(self, collection_name: str):
        """Delete the prefixed collection and return the client's result."""
        return self.client.delete_collection(
            collection_name=self._prefixed(collection_name)
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Return up to ``limit`` nearest neighbours of the first query vector.

        Only ``vectors[0]`` is searched; additional vectors are ignored, so
        the result always contains exactly one row.
        """
        if limit is None:
            limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

        query_response = self.client.query_points(
            collection_name=self._prefixed(collection_name),
            query=vectors[0],
            limit=limit,
        )
        get_result = self._result_to_get_result(query_response.points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            distances=[[point.score for point in query_response.points]],
        )

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """Return the points whose metadata matches ``filter``.

        Each ``key: value`` pair becomes an equality condition on
        ``metadata.<key>``. Returns ``None`` when the collection does not
        exist or when the lookup raises (the exception is printed).
        """
        if not self.has_collection(collection_name):
            return None
        try:
            if limit is None:
                limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

            field_conditions = [
                models.FieldCondition(
                    key=f"metadata.{key}", match=models.MatchValue(value=value)
                )
                for key, value in filter.items()
            ]

            # NOTE(review): `should` ORs the conditions, whereas `delete`
            # combines filter keys with `must` (AND). The two only differ for
            # multi-key filters -- confirm which semantics callers expect.
            points = self.client.query_points(
                collection_name=self._prefixed(collection_name),
                query_filter=models.Filter(should=field_conditions),
                limit=limit,
            )
            return self._result_to_get_result(points.points)
        except Exception as e:
            print(e)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return all points stored in the collection."""
        points = self.client.query_points(
            collection_name=self._prefixed(collection_name),
            limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
        )
        return self._result_to_get_result(points.points)

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Upload ``items``, creating the collection first if it is missing.

        The vector size of a new collection is taken from the first item.
        An empty ``items`` list is a no-op.
        """
        if not items:
            # Nothing to upload, and no first item to read the dimension from.
            return None
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        self.client.upload_points(self._prefixed(collection_name), points)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection if it is missing.

        Returns the client's upsert result, or ``None`` for an empty
        ``items`` list (nothing is sent in that case).
        """
        if not items:
            return None
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        return self.client.upsert(self._prefixed(collection_name), points)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete points selected by ``ids`` or, failing that, by ``filter``.

        ``ids`` are matched against the ``metadata.id`` payload field; a
        point is deleted when its value equals any of the given ids.
        ``filter`` pairs are equality conditions on ``metadata.<key>`` that
        must all hold. When neither selector yields a condition nothing is
        deleted and ``None`` is returned.
        """
        field_conditions = []

        if ids:
            # One "any of" condition: separate equality conditions under
            # `must` would require metadata.id to equal every id at once.
            field_conditions.append(
                models.FieldCondition(
                    key="metadata.id",
                    match=models.MatchAny(any=list(ids)),
                )
            )
        elif filter:
            for key, value in filter.items():
                field_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    )
                )

        if not field_conditions:
            # An empty filter selects every point; refuse to wipe the
            # collection implicitly (delete_collection exists for that).
            return None

        return self.client.delete(
            collection_name=self._prefixed(collection_name),
            points_selector=models.FilterSelector(
                filter=models.Filter(must=field_conditions)
            ),
        )

    def reset(self):
        """Delete every collection whose name starts with the prefix."""
        for collection in self.client.get_collections().collections:
            if collection.name.startswith(self.collection_prefix):
                self.client.delete_collection(collection_name=collection.name)

View File

@@ -4,8 +4,13 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Chat DB Schema
@@ -18,13 +23,16 @@ class Chat(Base):
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(Text) # Save Chat JSON as Text
chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
class ChatModel(BaseModel):
@@ -33,13 +41,16 @@ class ChatModel(BaseModel):
id: str
user_id: str
title: str
chat: str
chat: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
####################
@@ -64,6 +75,8 @@ class ChatResponse(BaseModel):
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
class ChatTitleIdResponse(BaseModel):
@@ -86,7 +99,7 @@ class ChatTable:
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@@ -101,14 +114,14 @@ class ChatTable:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.updated_at = int(time.time())
chat_item = db.get(Chat, id)
chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time())
db.commit()
db.refresh(chat_obj)
db.refresh(chat_item)
return ChatModel.model_validate(chat_obj)
return ChatModel.model_validate(chat_item)
except Exception:
return None
@@ -182,11 +195,24 @@ class ChatTable:
except Exception:
return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
    """Flip the ``pinned`` flag of a chat and refresh its ``updated_at``.

    Returns the updated chat as a ``ChatModel``, or ``None`` when anything
    fails (including the chat not being found).
    """
    try:
        with get_db() as session:
            record = session.get(Chat, id)
            # A NULL/None flag counts as "not pinned", so it becomes True.
            record.pinned = not record.pinned
            record.updated_at = int(time.time())
            session.commit()
            session.refresh(record)
            return ChatModel.model_validate(record)
    except Exception:
        return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
@@ -249,10 +275,10 @@ class ChatTable:
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
if limit:
query = query.limit(limit)
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
@@ -328,6 +354,15 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
    """Return the user's pinned chats, most recently updated first."""
    with get_db() as session:
        pinned = session.query(Chat).filter_by(user_id=user_id, pinned=True)
        newest_first = pinned.order_by(Chat.updated_at.desc())
        # Iterating the query object executes it.
        return [ChatModel.model_validate(row) for row in newest_first]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
@@ -337,6 +372,207 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
    self,
    user_id: str,
    search_text: str,
    include_archived: bool = False,
    skip: int = 0,
    limit: int = 60,
) -> list[ChatModel]:
    """
    Return the user's chats whose title or message content contains the search text.

    The text is lower-cased and stripped first; an empty result delegates to
    ``get_chat_list_by_user_id``. Matching and pagination (``skip`` /
    ``limit``) run in SQL, newest chats first. Only the ``sqlite`` and
    ``postgresql`` dialects are supported; any other raises
    ``NotImplementedError``.
    """
    search_text = search_text.lower().strip()
    if not search_text:
        return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)

    with get_db() as db:
        dialect_name = db.bind.dialect.name

        # Dialect-specific EXISTS clause over the messages of the chat JSON.
        if dialect_name == "sqlite":
            # SQLite: JSON1 extension
            content_clause = (
                "\n"
                "EXISTS (\n"
                "SELECT 1\n"
                "FROM json_each(Chat.chat, '$.messages') AS message\n"
                "WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'\n"
                ")\n"
            )
        elif dialect_name == "postgresql":
            # PostgreSQL: native JSON functions
            content_clause = (
                "\n"
                "EXISTS (\n"
                "SELECT 1\n"
                "FROM json_array_elements(Chat.chat->'messages') AS message\n"
                "WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'\n"
                ")\n"
            )
        else:
            raise NotImplementedError(
                f"Unsupported dialect: {db.bind.dialect.name}"
            )

        query = db.query(Chat).filter(Chat.user_id == user_id)
        if not include_archived:
            query = query.filter(Chat.archived == False)
        query = query.order_by(Chat.updated_at.desc())

        # Case-insensitive title match OR a message-content match.
        title_or_content = (
            Chat.title.ilike(f"%{search_text}%") | text(content_clause)
        ).params(search_text=search_text)
        query = query.filter(title_or_content)

        # Pagination happens at the SQL level.
        page = query.offset(skip).limit(limit).all()
        return [ChatModel.model_validate(row) for row in page]
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
    """Return the tag records referenced by a chat's ``meta["tags"]`` list.

    Args:
        id: Primary key of the chat.
        user_id: Owner used to resolve each tag.

    Returns:
        The resolved tags, in the order stored on the chat. An unknown chat
        yields an empty list, and tag ids that no longer resolve are skipped
        so the result never contains ``None``.
    """
    with get_db() as db:
        chat = db.get(Chat, id)
        if chat is None:
            # Session.get returns None for a missing primary key.
            return []

        tag_ids = (chat.meta or {}).get("tags", [])
        resolved = (
            Tags.get_tag_by_name_and_user_id(tag_id, user_id) for tag_id in tag_ids
        )
        return [tag for tag in resolved if tag is not None]
def get_chat_list_by_user_id_and_tag_name(
    self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
    """Return the user's chats whose ``meta["tags"]`` contains the given tag.

    Args:
        user_id: Owner of the chats.
        tag_name: Tag name; normalised to its id form (spaces replaced by
            underscores, lower-cased) before matching.
        skip: Number of matching chats to skip; ``0`` skips none.
        limit: Maximum number of chats to return; a falsy value disables
            the cap.

    Returns:
        Matching chats, most recently updated first.

    Raises:
        NotImplementedError: If the database dialect is neither ``sqlite``
            nor ``postgresql``.
    """
    with get_db() as db:
        query = db.query(Chat).filter_by(user_id=user_id)
        tag_id = tag_name.replace(" ", "_").lower()

        dialect_name = db.bind.dialect.name
        if dialect_name == "sqlite":
            # SQLite JSON1 querying for tags within the meta JSON field
            query = query.filter(
                text(
                    "EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
                )
            ).params(tag_id=tag_id)
        elif dialect_name == "postgresql":
            # PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
            query = query.filter(
                text(
                    "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
                )
            ).params(tag_id=tag_id)
        else:
            raise NotImplementedError(f"Unsupported dialect: {dialect_name}")

        # A defined order keeps offset/limit pages stable between calls.
        query = query.order_by(Chat.updated_at.desc())

        # Honour the pagination arguments, which were previously ignored.
        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)

        return [ChatModel.model_validate(chat) for chat in query.all()]
def add_chat_tag_by_id_and_user_id_and_tag_name(
    self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
    """Attach a tag to a chat, creating the tag record when it is missing.

    The tag's id is appended to the chat's ``meta["tags"]`` list unless it
    is already present. Returns the chat as a ``ChatModel``, or ``None``
    when the update fails.
    """
    tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
    if tag is None:
        tag = Tags.insert_new_tag(tag_name, user_id)

    try:
        with get_db() as session:
            record = session.get(Chat, id)
            tag_id = tag.id
            current_tags = record.meta.get("tags", [])

            if tag_id not in current_tags:
                # Assign a new dict rather than mutating the stored one.
                updated_meta = dict(record.meta)
                updated_meta["tags"] = current_tags + [tag_id]
                record.meta = updated_meta

            session.commit()
            session.refresh(record)
            return ChatModel.model_validate(record)
    except Exception:
        return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
    """Count the user's chats whose ``meta["tags"]`` contains the given tag.

    The tag name is normalised to its id form (spaces replaced by
    underscores, lower-cased). Raises ``NotImplementedError`` for dialects
    other than ``sqlite`` and ``postgresql``. The count is also printed.
    """
    # One EXISTS clause per supported dialect, bound to :tag_id.
    tag_clauses = {
        "sqlite": "EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)",
        "postgresql": "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)",
    }

    with get_db() as db:
        # Normalize the tag_name for consistency
        tag_id = tag_name.replace(" ", "_").lower()

        clause = tag_clauses.get(db.bind.dialect.name)
        if clause is None:
            raise NotImplementedError(
                f"Unsupported dialect: {db.bind.dialect.name}"
            )

        matching = (
            db.query(Chat)
            .filter_by(user_id=user_id)
            .filter(text(clause))
            .params(tag_id=tag_id)
        )
        count = matching.count()

        # Debugging output for inspection
        print(f"Count of chats for tag '{tag_name}':", count)
        return count
def delete_tag_by_id_and_user_id_and_tag_name(
    self, id: str, user_id: str, tag_name: str
) -> bool:
    """Remove one tag from a chat's ``meta["tags"]`` list.

    The tag name is normalised to its id form before comparison. Returns
    ``True`` once the change is committed and ``False`` on any failure.
    """
    try:
        with get_db() as session:
            record = session.get(Chat, id)
            existing = record.meta.get("tags", [])
            tag_id = tag_name.replace(" ", "_").lower()

            # Assign a fresh dict instead of mutating the stored one.
            updated_meta = dict(record.meta)
            updated_meta["tags"] = [entry for entry in existing if entry != tag_id]
            record.meta = updated_meta

            session.commit()
            return True
    except Exception:
        return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
    """Clear the ``meta["tags"]`` list of a chat.

    Returns ``True`` once the change is committed and ``False`` on any
    failure (for example when the chat does not exist).
    """
    try:
        with get_db() as session:
            record = session.get(Chat, id)

            # Keep the other meta keys; only the tag list is emptied.
            cleared_meta = dict(record.meta)
            cleared_meta["tags"] = []
            record.meta = cleared_meta

            session.commit()
            return True
    except Exception:
        return False
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:

View File

@@ -50,6 +50,14 @@ class FileModel(BaseModel):
####################
class FileMeta(BaseModel):
    # Metadata attached to a file record. All declared fields are optional
    # and default to None.
    name: Optional[str] = None
    content_type: Optional[str] = None
    size: Optional[int] = None
    # extra="allow": keys beyond the three declared above are accepted.
    model_config = ConfigDict(extra="allow")
class FileModelResponse(BaseModel):
id: str
user_id: str
@@ -57,12 +65,19 @@ class FileModelResponse(BaseModel):
filename: str
data: Optional[dict] = None
meta: dict
meta: FileMeta
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class FileMetadataResponse(BaseModel):
    # Reduced view of a file record: only id, meta and the two timestamps.
    id: str
    meta: dict
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
class FileForm(BaseModel):
id: str
hash: Optional[str] = None
@@ -104,6 +119,19 @@ class FilesTable:
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
    """Return the id, meta and timestamps of one file.

    Returns ``None`` when the lookup or conversion fails, which includes
    the file not being found.
    """
    with get_db() as session:
        try:
            record = session.get(File, id)
            summary = {
                "id": record.id,
                "meta": record.meta,
                "created_at": record.created_at,
                "updated_at": record.updated_at,
            }
            return FileMetadataResponse(**summary)
        except Exception:
            return None
def get_files(self) -> list[FileModel]:
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
@@ -118,6 +146,21 @@ class FilesTable:
.all()
]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
    """Return id, meta and timestamps for the files whose id is in ``ids``.

    Results are ordered by ``updated_at``, newest first.
    """
    with get_db() as session:
        records = (
            session.query(File)
            .filter(File.id.in_(ids))
            .order_by(File.updated_at.desc())
            .all()
        )

        summaries = []
        for record in records:
            summaries.append(
                FileMetadataResponse(
                    id=record.id,
                    meta=record.meta,
                    created_at=record.created_at,
                    updated_at=record.updated_at,
                )
            )
        return summaries
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db:
return [

View File

@@ -6,6 +6,10 @@ import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel):
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel):
name: str

View File

@@ -4,53 +4,32 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, JSON
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
class Tag(Base):
__tablename__ = "tag"
id = Column(String, primary_key=True)
name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Base):
__tablename__ = "chatidtag"
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
meta = Column(JSON, nullable=True)
class TagModel(BaseModel):
id: str
name: str
user_id: str
data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
tag_name: str
chat_id: str
user_id: str
timestamp: int
meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
@@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel):
####################
class ChatIdTagForm(BaseModel):
tag_name: str
class TagChatIdForm(BaseModel):
name: str
chat_id: str
class TagChatIdsResponse(BaseModel):
chat_ids: list[str]
class ChatTagsResponse(BaseModel):
tags: list[str]
class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = str(uuid.uuid4())
id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
@@ -93,170 +64,38 @@ class TagTable:
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
id = name.replace(" ", "_").lower()
with get_db() as db:
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag is None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> list[TagModel]:
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> list[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_db() as db:
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable()

View File

@@ -18,6 +18,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
@@ -27,6 +29,7 @@ from open_webui.utils.utils import (
create_api_key,
create_token,
get_admin_user,
get_verified_user,
get_current_user,
get_password_hash,
)
@@ -53,6 +56,8 @@ async def get_session_user(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
@@ -71,7 +76,7 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
):
if session_user:
user = Users.update_user_by_id(
@@ -166,6 +171,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
@@ -236,6 +243,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
if request.app.state.config.WEBHOOK_URL:

View File

@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import (
ChatIdTagForm,
ChatIdTagModel,
TagModel,
Tags,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@@ -95,7 +91,7 @@ async def get_user_chat_list_by_user_id(
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
@@ -108,10 +104,46 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################
@router.get("/search", response_model=list[ChatTitleIdResponse])
async def search_user_chats(
    text: str, page: Optional[int] = None, user=Depends(get_verified_user)
):
    """Return one page of the requesting user's chats that match ``text``.

    Args:
        text: Search text forwarded to
            ``Chats.get_chats_by_user_id_and_search_text``.
        page: 1-based page number. A missing, zero or negative value is
            treated as the first page.
        user: The verified user injected by ``get_verified_user``.

    Returns:
        A list of ``ChatTitleIdResponse`` objects, at most 60 per page.
    """
    # Clamp to the first page so the computed offset can never be negative.
    if page is None or page < 1:
        page = 1

    limit = 60
    skip = (page - 1) * limit

    return [
        ChatTitleIdResponse(**chat.model_dump())
        for chat in Chats.get_chats_by_user_id_and_search_text(
            user.id, text, skip=skip, limit=limit
        )
    ]
############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
    """Return the chats yielded by ``Chats.get_pinned_chats_by_user_id`` for the
    requesting user, each converted to a ``ChatResponse``."""
    pinned_chats = Chats.get_pinned_chats_by_user_id(user.id)
    return [ChatResponse(**item.model_dump()) for item in pinned_chats]
############################
# GetChats
############################
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id)
]
@@ -124,11 +156,28 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
    """Return the result of ``Tags.get_tags_by_user_id`` for the requesting user.

    Raises:
        HTTPException: 400 with the default error message if the lookup
            raises; the original exception is logged first.
    """
    try:
        return Tags.get_tags_by_user_id(user.id)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
############################
# GetAllChatsInDB
############################
@@ -141,10 +190,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats()
]
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
############################
@@ -187,7 +233,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -199,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################
class TagNameForm(BaseModel):
class TagForm(BaseModel):
name: str
class TagFilterForm(TagForm):
skip: Optional[int] = 0
limit: Optional[int] = 50
@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user)
form_data: TagFilterForm, user=Depends(get_verified_user)
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
chats = Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit
)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=list[TagModel])
async def get_all_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatById
############################
@@ -251,7 +278,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -269,10 +297,9 @@ async def update_chat_by_id(
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -303,25 +330,57 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result
############################
# GetPinnedStatusById
############################
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
    """Return the ``pinned`` attribute of the chat ``id`` owned by the user.

    Raises:
        HTTPException: 401 with the default error message when
            ``Chats.get_chat_by_id_and_user_id`` returns nothing for this
            chat id and user.
    """
    owned_chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not owned_chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
    return owned_chat.pinned
############################
# PinChatById
############################
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Call ``Chats.toggle_chat_pinned_by_id`` for a chat owned by the user and
    return its result.

    Raises:
        HTTPException: 401 with the default error message when
            ``Chats.get_chat_by_id_and_user_id`` returns nothing for this
            chat id and user.
    """
    # Ownership is checked first; the toggle itself is keyed on the chat id only.
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
    return Chats.toggle_chat_pinned_by_id(id)
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
**chat.chat,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"branchPointMessageId": chat.chat["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -333,12 +392,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -356,9 +415,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
@@ -366,10 +423,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
return ChatResponse(**shared_chat.model_dump())
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -407,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
return tags
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -422,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower()
if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(user.id, form_data)
if tag:
return tag
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
print(tags, tag_id)
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -449,16 +506,20 @@ async def add_chat_tag_by_id(
############################
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
if result:
return result
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -472,10 +533,17 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
if result:
return result
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

View File

@@ -213,7 +213,7 @@ async def update_file_data_content_by_id(
############################
@router.get("/{id}/content", response_model=Optional[FileModel])
@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
@@ -223,7 +223,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -236,7 +239,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
@router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
@@ -248,7 +251,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

View File

@@ -48,7 +48,12 @@ async def get_knowledge_items(
)
else:
return [
KnowledgeResponse(**knowledge.model_dump())
KnowledgeResponse(
**knowledge.model_dump(),
files=Files.get_file_metadatas_by_ids(
knowledge.data.get("file_ids", []) if knowledge.data else []
),
)
for knowledge in Knowledges.get_knowledge_items()
]

View File

@@ -901,6 +901,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
####################################
# Information Retrieval (RAG)
####################################
@@ -986,10 +989,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_OPENAI_BATCH_SIZE",
"rag.embedding_openai_batch_size",
int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_BATCH_SIZE",
"rag.embedding_batch_size",
int(
os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
),
)
RAG_RERANKING_MODEL = PersistentConfig(

View File

@@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
####################################
# WEBUI_AUTH (Required for security)
####################################
@@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@@ -355,3 +360,9 @@ else:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
####################################
# OFFLINE_MODE
####################################
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"

View File

@@ -102,6 +102,7 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from fastapi import (
Depends,
@@ -178,14 +179,14 @@ class SPAStaticFiles(StaticFiles):
print(
rf"""
___ __ __ _ _ _ ___
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
|_|
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
@@ -824,6 +825,32 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware)
from urllib.parse import urlencode, parse_qs, urlparse
class RedirectMiddleware(BaseHTTPMiddleware):
    """Redirect ``GET .../watch?v=<id>`` requests to ``/?youtube=<id>``.

    Every other request is passed to the next handler unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        is_watch_request = request.method == "GET" and request.url.path.endswith(
            "/watch"
        )
        if is_watch_request:
            params = dict(parse_qs(urlparse(str(request.url)).query))
            if "v" in params:
                # parse_qs maps each key to a list of values; only the first
                # 'v' value is forwarded, URL-encoded, as the 'youtube' param.
                target = f"/?{urlencode({'youtube': params['v'][0]})}"
                return RedirectResponse(url=target)

        # Anything that is not a matching watch URL continues down the stack.
        return await call_next(request)
# Add the middleware to the app
app.add_middleware(RedirectMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -2181,6 +2208,11 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
if OFFLINE_MODE:
log.debug(
f"Offline mode is enabled, returning current version as latest version"
)
return {"current": VERSION, "latest": VERSION}
try:
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@@ -2353,6 +2385,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
@@ -2416,6 +2450,7 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
if os.path.exists(FRONTEND_BUILD_DIR):
mimetypes.add_type("text/javascript", ".js")
app.mount(

View File

@@ -0,0 +1,151 @@
"""Migrate tags
Revision ID: 1af9b942657b
Revises: 242a2047eae0
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
branch_labels = None
depends_on = None
def upgrade():
    """Normalise tag ids and move chat tags / pinned state onto ``chat``.

    Steps:
      1. On ``tag`` (when the table exists): ensure the ``uq_id_user_id``
         unique constraint, drop the ``data`` column and add a JSON ``meta``
         column, using batch mode for SQLite support.
      2. Rewrite every tag id to its name with spaces replaced by ``_`` and
         lower-cased. Tags that normalise to ``pinned`` are deleted, as are
         tags whose normalised id is already taken by another row.
      3. Add ``pinned`` and ``meta`` columns to ``chat`` and populate them
         from the rows of ``chatidtag`` (when that table exists).
    """
    conn = op.get_bind()
    inspector = Inspector.from_engine(conn)

    # Clean up potential leftover temp table from previous failures
    conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))

    tables = inspector.get_table_names()

    if "tag" in tables:
        # Step 1: Modify Tag table using batch mode for SQLite support
        columns = [col["name"] for col in inspector.get_columns("tag")]
        current_constraints = inspector.get_unique_constraints("tag")

        with op.batch_alter_table("tag", schema=None) as batch_op:
            # Each change is applied only when it is still missing.
            if not any(
                constraint["name"] == "uq_id_user_id"
                for constraint in current_constraints
            ):
                batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])

            if "data" in columns:
                batch_op.drop_column("data")

            if "meta" not in columns:
                batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))

        tag = table(
            "tag",
            column("id", sa.String()),
            column("name", sa.String()),
            column("user_id", sa.String()),
            column("meta", sa.JSON()),
        )

        # Step 2: Migrate tags. The id mapping is collected up front so no
        # result set is being iterated while rows are updated or deleted.
        result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
        tag_updates = {row.id: row.name.replace(" ", "_").lower() for row in result}

        for tag_id, new_tag_id in tag_updates.items():
            if new_tag_id == "pinned":
                # 'pinned' becomes a chat column below, so the tag row is dropped.
                print(f"Updating tag {tag_id} to {new_tag_id}")
                conn.execute(sa.delete(tag).where(tag.c.id == tag_id))
                continue

            if tag_id == new_tag_id:
                # Already normalised: without this guard the row would match
                # itself in the duplicate check below and be deleted.
                continue

            print(f"Updating tag {tag_id} to {new_tag_id}")

            # NOTE(review): this lookup is not scoped by user_id, so a tag is
            # dropped when any user already holds the same normalised id —
            # confirm this is intended given the (id, user_id) constraint.
            existing_tag_result = conn.execute(
                sa.select(tag.c.id).where(tag.c.id == new_tag_id)
            ).fetchone()

            if existing_tag_result:
                print(
                    f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
                )
                conn.execute(sa.delete(tag).where(tag.c.id == tag_id))
            else:
                conn.execute(
                    sa.update(tag).where(tag.c.id == tag_id).values(id=new_tag_id)
                )

    # Add columns `pinned` and `meta` to 'chat'
    op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
    op.add_column(
        "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
    )

    if "chatidtag" in tables:
        chatidtag = table(
            "chatidtag",
            column("chat_id", sa.String()),
            column("tag_name", sa.String()),
        )
        chat = table(
            "chat",
            column("id", sa.String()),
            column("pinned", sa.Boolean()),
            column("meta", sa.JSON()),
        )

        # Accumulate the per-chat changes first, then write them in one pass.
        result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))

        chat_updates = {}
        for row in result:
            tag_name = row.tag_name.replace(" ", "_").lower()

            if tag_name == "pinned":
                # The 'pinned' tag turns into the boolean column, not a tag entry.
                entry = chat_updates.setdefault(
                    row.chat_id, {"pinned": True, "meta": {}}
                )
                entry["pinned"] = True
            else:
                entry = chat_updates.setdefault(
                    row.chat_id, {"pinned": False, "meta": {}}
                )
                tags = entry["meta"].setdefault("tags", [])
                # Different names can normalise to the same id; keep each once.
                if tag_name not in tags:
                    tags.append(tag_name)

        for chat_id, updates in chat_updates.items():
            conn.execute(
                sa.update(chat)
                .where(chat.c.id == chat_id)
                .values(
                    meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
                )
            )
def downgrade():
    """No-op: this revision does not implement a downgrade path."""
    pass

View File

@@ -0,0 +1,82 @@
"""Update chat table
Revision ID: 242a2047eae0
Revises: 6a39f3d8e55c
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update
import json
revision = "242a2047eae0"
down_revision = "6a39f3d8e55c"
branch_labels = None
depends_on = None
def upgrade():
    """Replace the Text ``chat.chat`` column with a JSON column of the same name.

    The existing column is renamed to ``old_chat``, a new nullable JSON
    ``chat`` column is added, every row's text is parsed into it, and
    ``old_chat`` is dropped. Rows whose text is NULL or not valid JSON end up
    with a NULL ``chat`` value.
    """
    # Step 1: Rename current 'chat' column to 'old_chat'
    op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text)

    # Step 2: Add new 'chat' column of type JSON
    op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))

    # Step 3: Migrate data from 'old_chat' to 'chat'
    chat_table = table(
        "chat",
        sa.Column("id", sa.String, primary_key=True),
        sa.Column("old_chat", sa.Text),
        sa.Column("chat", sa.JSON()),
    )

    connection = op.get_bind()
    # Fetch everything first so the table is not updated while the SELECT
    # result is still being iterated.
    rows = connection.execute(
        select(chat_table.c.id, chat_table.c.old_chat)
    ).fetchall()

    for row in rows:
        try:
            json_data = json.loads(row.old_chat)
        except (TypeError, json.JSONDecodeError):
            # TypeError: old_chat is NULL; JSONDecodeError: text is not JSON.
            json_data = None

        connection.execute(
            sa.update(chat_table)
            .where(chat_table.c.id == row.id)
            .values(chat=json_data)
        )

    # Step 4: Drop 'old_chat' column
    op.drop_column("chat", "old_chat")
def downgrade():
    """Turn the JSON ``chat.chat`` column back into a Text column.

    A temporary ``old_chat`` Text column receives each row's JSON value
    serialised with ``json.dumps`` (NULL stays NULL); the JSON column is then
    dropped and ``old_chat`` is renamed to ``chat``.
    """
    # Temporary Text column that will take over the 'chat' name at the end.
    op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))

    chat_table = table(
        "chat",
        sa.Column("id", sa.String, primary_key=True),
        sa.Column("chat", sa.JSON()),
        sa.Column("old_chat", sa.Text()),
    )

    bind = op.get_bind()
    for record in bind.execute(select(chat_table.c.id, chat_table.c.chat)):
        serialized = None if record.chat is None else json.dumps(record.chat)
        statement = (
            sa.update(chat_table)
            .where(chat_table.c.id == record.id)
            .values(old_chat=serialized)
        )
        bind.execute(statement)

    # Swap the columns: remove the JSON one, then give the Text one its name.
    op.drop_column("chat", "chat")
    op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)

View File

@@ -1,105 +0,0 @@
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest):
    """Integration test covering the documents API under ``/api/v1/documents``."""

    BASE_PATH = "/api/v1/documents"

    def setup_class(cls):
        """Run the parent class setup, then expose the ``Documents`` model on the class."""
        super().setup_class()
        # Imported inside the hook rather than at module level
        # (presumably so the model loads only after the base setup ran; confirm).
        from open_webui.apps.webui.models.documents import Documents

        cls.documents = Documents

    def test_documents(self):
        """Walk through list, create, fetch, update, tag and delete as mocked user "2"."""
        # Empty database
        assert len(self.documents.get_docs()) == 0
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.get(self.create_url("/"))
        assert resp.status_code == 200
        assert len(resp.json()) == 0

        # Create a new document
        first_doc = {
            "name": "doc_name",
            "title": "doc title",
            "collection_name": "custom collection",
            "filename": "doc_name.pdf",
            "content": "",
        }
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.post(self.create_url("/create"), json=first_doc)
        assert resp.status_code == 200
        assert resp.json()["name"] == "doc_name"
        assert len(self.documents.get_docs()) == 1

        # Get the document
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
        assert resp.status_code == 200
        body = resp.json()
        assert body["collection_name"] == "custom collection"
        assert body["name"] == "doc_name"
        assert body["title"] == "doc title"
        assert body["filename"] == "doc_name.pdf"
        assert body["content"] == {}

        # Create another document
        second_doc = {
            "name": "doc_name 2",
            "title": "doc title 2",
            "collection_name": "custom collection 2",
            "filename": "doc_name2.pdf",
            "content": "",
        }
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.post(
                self.create_url("/create"), json=second_doc
            )
        assert resp.status_code == 200
        assert resp.json()["name"] == "doc_name 2"
        assert len(self.documents.get_docs()) == 2

        # Get all documents
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.get(self.create_url("/"))
        assert resp.status_code == 200
        assert len(resp.json()) == 2

        # Update the first document
        changes = {"name": "doc_name rework", "title": "updated title"}
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.post(
                self.create_url("/doc/update?name=doc_name"), json=changes
            )
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "doc_name rework"
        assert body["title"] == "updated title"

        # Tag the first document
        tags = [{"name": "testing-tag"}, {"name": "another-tag"}]
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.post(
                self.create_url("/doc/tags"),
                json={"name": "doc_name rework", "tags": tags},
            )
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "doc_name rework"
        assert body["content"] == {"tags": tags}
        assert len(self.documents.get_docs()) == 2

        # Delete the first document
        with mock_webui_user(id="2"):
            resp = self.fast_api_client.delete(
                self.create_url("/doc/delete?name=doc_name rework")
            )
        assert resp.status_code == 200
        assert len(self.documents.get_docs()) == 1

View File

@@ -12,6 +12,7 @@ passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.10.8
async-timeout
sqlalchemy==2.0.32
alembic==1.13.2
@@ -41,6 +42,7 @@ langchain-chroma==0.1.4
fake-useragent==1.5.1
chromadb==0.5.9
pymilvus==2.4.7
qdrant-client~=1.12.0
sentence-transformers==3.0.1
colbert-ai==0.2.21