diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py
index 0e5672013..6398b9ee1 100644
--- a/backend/open_webui/apps/audio/main.py
+++ b/backend/open_webui/apps/audio/main.py
@@ -450,7 +450,7 @@ def transcribe(file_path):
except Exception:
error_detail = f"External: {e}"
- raise error_detail
+ raise Exception(error_detail)
@app.post("/transcriptions")
diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/apps/images/utils/comfyui.py
index 0a3e3a1d9..4c421d7c5 100644
--- a/backend/open_webui/apps/images/utils/comfyui.py
+++ b/backend/open_webui/apps/images/utils/comfyui.py
@@ -125,22 +125,34 @@ async def comfyui_generate_image(
workflow[node_id]["inputs"][node.key] = model
elif node.type == "prompt":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["text"] = payload.prompt
+ workflow[node_id]["inputs"][
+ node.key if node.key else "text"
+ ] = payload.prompt
elif node.type == "negative_prompt":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["text"] = payload.negative_prompt
+ workflow[node_id]["inputs"][
+ node.key if node.key else "text"
+ ] = payload.negative_prompt
elif node.type == "width":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["width"] = payload.width
+ workflow[node_id]["inputs"][
+ node.key if node.key else "width"
+ ] = payload.width
elif node.type == "height":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["height"] = payload.height
+ workflow[node_id]["inputs"][
+ node.key if node.key else "height"
+ ] = payload.height
elif node.type == "n":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["batch_size"] = payload.n
+ workflow[node_id]["inputs"][
+ node.key if node.key else "batch_size"
+ ] = payload.n
elif node.type == "steps":
for node_id in node.node_ids:
- workflow[node_id]["inputs"]["steps"] = payload.steps
+ workflow[node_id]["inputs"][
+ node.key if node.key else "steps"
+ ] = payload.steps
elif node.type == "seed":
seed = (
payload.seed
diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index 33d984655..f835e3175 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel):
model: str
- input: str
- truncate: Optional[bool]
+ input: list[str]|str
+ truncate: Optional[bool] = None
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None
@@ -560,48 +560,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
- if url_idx is None:
- model = form_data.model
-
- if ":" not in model:
- model = f"{model}:latest"
-
- if model in app.state.MODELS:
- url_idx = random.choice(app.state.MODELS[model]["urls"])
- else:
- raise HTTPException(
- status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
- )
-
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
- log.info(f"url: {url}")
-
- r = requests.request(
- method="POST",
- url=f"{url}/api/embed",
- headers={"Content-Type": "application/json"},
- data=form_data.model_dump_json(exclude_none=True).encode(),
- )
- try:
- r.raise_for_status()
-
- return r.json()
- except Exception as e:
- log.exception(e)
- error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = r.json()
- if "error" in res:
- error_detail = f"Ollama: {res['error']}"
- except Exception:
- error_detail = f"Ollama: {e}"
-
- raise HTTPException(
- status_code=r.status_code if r else 500,
- detail=error_detail,
- )
+ return generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
- if url_idx is None:
- model = form_data.model
-
- if ":" not in model:
- model = f"{model}:latest"
-
- if model in app.state.MODELS:
- url_idx = random.choice(app.state.MODELS[model]["urls"])
- else:
- raise HTTPException(
- status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
- )
-
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
- log.info(f"url: {url}")
-
- r = requests.request(
- method="POST",
- url=f"{url}/api/embeddings",
- headers={"Content-Type": "application/json"},
- data=form_data.model_dump_json(exclude_none=True).encode(),
- )
- try:
- r.raise_for_status()
-
- return r.json()
- except Exception as e:
- log.exception(e)
- error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = r.json()
- if "error" in res:
- error_detail = f"Ollama: {res['error']}"
- except Exception:
- error_detail = f"Ollama: {e}"
-
- raise HTTPException(
- status_code=r.status_code if r else 500,
- detail=error_detail,
- )
+ return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
- return data["embedding"]
+ return data
+ else:
+ raise Exception("Something went wrong :/")
+ except Exception as e:
+ log.exception(e)
+ error_detail = "Open WebUI: Server Connection Error"
+ if r is not None:
+ try:
+ res = r.json()
+ if "error" in res:
+ error_detail = f"Ollama: {res['error']}"
+ except Exception:
+ error_detail = f"Ollama: {e}"
+
+ raise HTTPException(
+ status_code=r.status_code if r else 500,
+ detail=error_detail,
+ )
+
+
+def generate_ollama_batch_embeddings(
+ form_data: GenerateEmbedForm,
+ url_idx: Optional[int] = None,
+):
+ log.info(f"generate_ollama_batch_embeddings {form_data}")
+
+ if url_idx is None:
+ model = form_data.model
+
+ if ":" not in model:
+ model = f"{model}:latest"
+
+ if model in app.state.MODELS:
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+ )
+
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+ log.info(f"url: {url}")
+
+ r = requests.request(
+ method="POST",
+ url=f"{url}/api/embed",
+ headers={"Content-Type": "application/json"},
+ data=form_data.model_dump_json(exclude_none=True).encode(),
+ )
+ try:
+ r.raise_for_status()
+
+ data = r.json()
+
+ log.info(f"generate_ollama_batch_embeddings {data}")
+
+ if "embeddings" in data:
+ return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
@@ -788,8 +763,7 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
):
payload = {**form_data.model_dump(exclude_none=True)}
- log.debug(f"{payload = }")
-
+ log.debug(f"generate_chat_completion() - 1.payload = {payload}")
if "metadata" in payload:
del payload["metadata"]
@@ -824,7 +798,7 @@ async def generate_chat_completion(
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
- log.debug(payload)
+ log.debug(f"generate_chat_completion() - 2.payload = {payload}")
return await post_streaming_url(
f"{url}/api/chat",
diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py
index 52cebeabc..c80b2011d 100644
--- a/backend/open_webui/apps/retrieval/main.py
+++ b/backend/open_webui/apps/retrieval/main.py
@@ -63,7 +63,7 @@ from open_webui.config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
- RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ RAG_EMBEDDING_BATCH_SIZE,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
- "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
}
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
- "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
- batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
+ embedding_batch_size: Optional[int] = 1
@app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
- form_data.openai_config.batch_size
- if form_data.openai_config.batch_size
- else 1
- )
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -334,17 +330,17 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
- "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -645,7 +641,7 @@ def save_docs_to_vector_db(
filter={"hash": metadata["hash"]},
)
- if result:
+ if result is not None:
existing_doc_ids = result.ids[0]
if existing_doc_ids:
log.info(f"Document with hash {metadata['hash']} already exists")
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
embeddings = embedding_function(
@@ -767,7 +763,7 @@ def process_file(
collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
- if len(result.ids[0]) > 0:
+ if result is not None and len(result.ids[0]) > 0:
docs = [
Document(
page_content=result.documents[0][idx],
diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 53961be2c..4ca2db1bd 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
- GenerateEmbeddingsForm,
- generate_ollama_embeddings,
+ GenerateEmbedForm,
+ generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@@ -193,7 +193,8 @@ def query_collection(
k=k,
query_embedding=query_embedding,
)
- results.append(result.model_dump())
+ if result is not None:
+ results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
@@ -265,39 +266,27 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
- batch_size,
+ embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
- if embedding_engine == "ollama":
- func = lambda query: generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{
- "model": embedding_model,
- "prompt": query,
- }
- )
- )
- elif embedding_engine == "openai":
- func = lambda query: generate_openai_embeddings(
- model=embedding_model,
- text=query,
- key=openai_key,
- url=openai_url,
- )
+ func = lambda query: generate_embeddings(
+ engine=embedding_engine,
+ model=embedding_model,
+ text=query,
+ key=openai_key if embedding_engine == "openai" else "",
+ url=openai_url if embedding_engine == "openai" else "",
+ )
- def generate_multiple(query, f):
+ def generate_multiple(query, func):
if isinstance(query, list):
- if embedding_engine == "openai":
- embeddings = []
- for i in range(0, len(query), batch_size):
- embeddings.extend(f(query[i : i + batch_size]))
- return embeddings
- else:
- return [f(q) for q in query]
+ embeddings = []
+ for i in range(0, len(query), embedding_batch_size):
+ embeddings.extend(func(query[i : i + embedding_batch_size]))
+ return embeddings
else:
- return f(query)
+ return func(query)
return lambda query: generate_multiple(query, func)
@@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False):
return model
-def generate_openai_embeddings(
- model: str,
- text: Union[str, list[str]],
- key: str,
- url: str = "https://api.openai.com/v1",
-):
- if isinstance(text, list):
- embeddings = generate_openai_batch_embeddings(model, text, key, url)
- else:
- embeddings = generate_openai_batch_embeddings(model, [text], key, url)
-
- return embeddings[0] if isinstance(text, str) else embeddings
-
-
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
@@ -482,6 +457,33 @@ def generate_openai_batch_embeddings(
return None
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+ if engine == "ollama":
+ if isinstance(text, list):
+ embeddings = generate_ollama_batch_embeddings(
+ GenerateEmbedForm(**{"model": model, "input": text})
+ )
+ else:
+ embeddings = generate_ollama_batch_embeddings(
+ GenerateEmbedForm(**{"model": model, "input": [text]})
+ )
+ return (
+ embeddings["embeddings"][0]
+ if isinstance(text, str)
+ else embeddings["embeddings"]
+ )
+ elif engine == "openai":
+ key = kwargs.get("key", "")
+ url = kwargs.get("url", "https://api.openai.com/v1")
+
+ if isinstance(text, list):
+ embeddings = generate_openai_batch_embeddings(model, text, key, url)
+ else:
+ embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+
+ return embeddings[0] if isinstance(text, str) else embeddings
+
+
import operator
from typing import Optional, Sequence
diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py
index 1f33b1721..c7f00f5fd 100644
--- a/backend/open_webui/apps/retrieval/vector/connector.py
+++ b/backend/open_webui/apps/retrieval/vector/connector.py
@@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
+elif VECTOR_DB == "qdrant":
+ from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
+
+ VECTOR_DB_CLIENT = QdrantClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
new file mode 100644
index 000000000..70908dc33
--- /dev/null
+++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
@@ -0,0 +1,176 @@
+from typing import Optional
+
+from qdrant_client import QdrantClient as Qclient
+from qdrant_client.http.models import PointStruct
+from qdrant_client.models import models
+
+from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import QDRANT_URI
+
+NO_LIMIT = 999999999
+
+class QdrantClient:
+ def __init__(self):
+ self.collection_prefix = "open-webui"
+ self.QDRANT_URI = QDRANT_URI
+ self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
+
+ def _result_to_get_result(self, points) -> GetResult:
+ ids = []
+ documents = []
+ metadatas = []
+
+ for point in points:
+ payload = point.payload
+ ids.append(point.id)
+ documents.append(payload["text"])
+ metadatas.append(payload["metadata"])
+
+ return GetResult(
+ **{
+ "ids": [ids],
+ "documents": [documents],
+ "metadatas": [metadatas],
+ }
+ )
+
+ def _create_collection(self, collection_name: str, dimension: int):
+ collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
+ self.client.create_collection(
+ collection_name=collection_name_with_prefix,
+ vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
+ )
+
+ print(f"collection {collection_name_with_prefix} successfully created!")
+
+ def _create_collection_if_not_exists(self, collection_name, dimension):
+ if not self.has_collection(
+ collection_name=collection_name
+ ):
+ self._create_collection(
+ collection_name=collection_name, dimension=dimension
+ )
+
+ def _create_points(self, items: list[VectorItem]):
+ return [
+ PointStruct(
+ id=item["id"],
+ vector=item["vector"],
+ payload={
+ "text": item["text"],
+ "metadata": item["metadata"]
+ },
+ )
+ for item in items
+ ]
+
+ def has_collection(self, collection_name: str) -> bool:
+ return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
+
+ def delete_collection(self, collection_name: str):
+ return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
+
+ def search(
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
+ ) -> Optional[SearchResult]:
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
+ if limit is None:
+ limit = NO_LIMIT # otherwise qdrant would set limit to 10!
+
+ query_response = self.client.query_points(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ query=vectors[0],
+ limit=limit,
+ )
+ get_result = self._result_to_get_result(query_response.points)
+ return SearchResult(
+ ids=get_result.ids,
+ documents=get_result.documents,
+ metadatas=get_result.metadatas,
+ distances=[[point.score for point in query_response.points]]
+ )
+
+ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+ # Construct the filter string for querying
+ if not self.has_collection(collection_name):
+ return None
+ try:
+ if limit is None:
+ limit = NO_LIMIT # otherwise qdrant would set limit to 10!
+
+ field_conditions = []
+ for key, value in filter.items():
+ field_conditions.append(
+ models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
+
+ points = self.client.query_points(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ query_filter=models.Filter(should=field_conditions),
+ limit=limit,
+ )
+ return self._result_to_get_result(points.points)
+ except Exception as e:
+ print(e)
+ return None
+
+ def get(self, collection_name: str) -> Optional[GetResult]:
+ # Get all the items in the collection.
+ points = self.client.query_points(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ limit=NO_LIMIT # otherwise qdrant would set limit to 10!
+ )
+ return self._result_to_get_result(points.points)
+
+ def insert(self, collection_name: str, items: list[VectorItem]):
+ # Insert the items into the collection, if the collection does not exist, it will be created.
+ self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
+ points = self._create_points(items)
+ self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
+
+ def upsert(self, collection_name: str, items: list[VectorItem]):
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
+ self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
+ points = self._create_points(items)
+ return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
+
+ def delete(
+ self,
+ collection_name: str,
+ ids: Optional[list[str]] = None,
+ filter: Optional[dict] = None,
+ ):
+ # Delete items from the collection matching the given ids, or the metadata filter if no ids are given.
+ field_conditions = []
+
+ if ids:
+ for id_value in ids:
+ field_conditions.append(
+ models.FieldCondition(
+ key="metadata.id",
+ match=models.MatchValue(value=id_value),
+ ),
+ ),
+ elif filter:
+ for key, value in filter.items():
+ field_conditions.append(
+ models.FieldCondition(
+ key=f"metadata.{key}",
+ match=models.MatchValue(value=value),
+ ),
+ ),
+
+ return self.client.delete(
+ collection_name=f"{self.collection_prefix}_{collection_name}",
+ points_selector=models.FilterSelector(
+ filter=models.Filter(
+ must=field_conditions
+ )
+ ),
+ )
+
+ def reset(self):
+ # Resets the database. This will delete all collections and item entries.
+ collection_names = self.client.get_collections().collections
+ for collection_name in collection_names:
+ if collection_name.name.startswith(self.collection_prefix):
+ self.client.delete_collection(collection_name=collection_name.name)
diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py
index f364dcc70..04355e997 100644
--- a/backend/open_webui/apps/webui/models/chats.py
+++ b/backend/open_webui/apps/webui/models/chats.py
@@ -4,8 +4,13 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
+
+
from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Boolean, Column, String, Text
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
####################
# Chat DB Schema
@@ -18,13 +23,16 @@ class Chat(Base):
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
- chat = Column(Text) # Save Chat JSON as Text
+ chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
+ pinned = Column(Boolean, default=False, nullable=True)
+
+ meta = Column(JSON, server_default="{}")
class ChatModel(BaseModel):
@@ -33,13 +41,16 @@ class ChatModel(BaseModel):
id: str
user_id: str
title: str
- chat: str
+ chat: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
+ pinned: Optional[bool] = False
+
+ meta: dict = {}
####################
@@ -64,6 +75,8 @@ class ChatResponse(BaseModel):
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
+ pinned: Optional[bool] = False
+ meta: dict = {}
class ChatTitleIdResponse(BaseModel):
@@ -86,7 +99,7 @@ class ChatTable:
if "title" in form_data.chat
else "New Chat"
),
- "chat": json.dumps(form_data.chat),
+ "chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
@@ -101,14 +114,14 @@ class ChatTable:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
- chat_obj = db.get(Chat, id)
- chat_obj.chat = json.dumps(chat)
- chat_obj.title = chat["title"] if "title" in chat else "New Chat"
- chat_obj.updated_at = int(time.time())
+ chat_item = db.get(Chat, id)
+ chat_item.chat = chat
+ chat_item.title = chat["title"] if "title" in chat else "New Chat"
+ chat_item.updated_at = int(time.time())
db.commit()
- db.refresh(chat_obj)
+ db.refresh(chat_item)
- return ChatModel.model_validate(chat_obj)
+ return ChatModel.model_validate(chat_item)
except Exception:
return None
@@ -182,11 +195,24 @@ class ChatTable:
except Exception:
return None
+ def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
+ try:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+ chat.pinned = not chat.pinned
+ chat.updated_at = int(time.time())
+ db.commit()
+ db.refresh(chat)
+ return ChatModel.model_validate(chat)
+ except Exception:
+ return None
+
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
+ chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
@@ -249,10 +275,10 @@ class ChatTable:
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
- if limit:
- query = query.limit(limit)
if skip:
query = query.offset(skip)
+ if limit:
+ query = query.limit(limit)
all_chats = query.all()
@@ -328,6 +354,15 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
+ def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
+ with get_db() as db:
+ all_chats = (
+ db.query(Chat)
+ .filter_by(user_id=user_id, pinned=True)
+ .order_by(Chat.updated_at.desc())
+ )
+ return [ChatModel.model_validate(chat) for chat in all_chats]
+
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
@@ -337,6 +372,207 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
+ def get_chats_by_user_id_and_search_text(
+ self,
+ user_id: str,
+ search_text: str,
+ include_archived: bool = False,
+ skip: int = 0,
+ limit: int = 60,
+ ) -> list[ChatModel]:
+ """
+ Filters chats by search text in the title or message contents (matched at the SQL level), with skip/limit pagination.
+ """
+ search_text = search_text.lower().strip()
+ if not search_text:
+ return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
+
+ with get_db() as db:
+ query = db.query(Chat).filter(Chat.user_id == user_id)
+
+ if not include_archived:
+ query = query.filter(Chat.archived == False)
+
+ query = query.order_by(Chat.updated_at.desc())
+
+ # Check if the database dialect is either 'sqlite' or 'postgresql'
+ dialect_name = db.bind.dialect.name
+ if dialect_name == "sqlite":
+ # SQLite case: using JSON1 extension for JSON searching
+ query = query.filter(
+ (
+ Chat.title.ilike(
+ f"%{search_text}%"
+ ) # Case-insensitive search in title
+ | text(
+ """
+ EXISTS (
+ SELECT 1
+ FROM json_each(Chat.chat, '$.messages') AS message
+ WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
+ )
+ """
+ )
+ ).params(search_text=search_text)
+ )
+ elif dialect_name == "postgresql":
+ # PostgreSQL: use json_array_elements over chat->'messages' to search message content
+ query = query.filter(
+ (
+ Chat.title.ilike(
+ f"%{search_text}%"
+ ) # Case-insensitive search in title
+ | text(
+ """
+ EXISTS (
+ SELECT 1
+ FROM json_array_elements(Chat.chat->'messages') AS message
+ WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
+ )
+ """
+ )
+ ).params(search_text=search_text)
+ )
+ else:
+ raise NotImplementedError(
+ f"Unsupported dialect: {db.bind.dialect.name}"
+ )
+
+ # Perform pagination at the SQL level
+ all_chats = query.offset(skip).limit(limit).all()
+
+ # Validate and return chats
+ return [ChatModel.model_validate(chat) for chat in all_chats]
+
+ def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+ tags = chat.meta.get("tags", [])
+ return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
+
+ def get_chat_list_by_user_id_and_tag_name(
+ self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
+ ) -> list[ChatModel]:
+ with get_db() as db:
+ query = db.query(Chat).filter_by(user_id=user_id)
+ tag_id = tag_name.replace(" ", "_").lower()
+
+ print(db.bind.dialect.name)
+ if db.bind.dialect.name == "sqlite":
+ # SQLite JSON1 querying for tags within the meta JSON field
+ query = query.filter(
+ text(
+ f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
+ )
+ ).params(tag_id=tag_id)
+ elif db.bind.dialect.name == "postgresql":
+ # PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
+ query = query.filter(
+ text(
+ "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
+ )
+ ).params(tag_id=tag_id)
+ else:
+ raise NotImplementedError(
+ f"Unsupported dialect: {db.bind.dialect.name}"
+ )
+
+ all_chats = query.all()
+ print("all_chats", all_chats)
+ return [ChatModel.model_validate(chat) for chat in all_chats]
+
+ def add_chat_tag_by_id_and_user_id_and_tag_name(
+ self, id: str, user_id: str, tag_name: str
+ ) -> Optional[ChatModel]:
+ tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
+ if tag is None:
+ tag = Tags.insert_new_tag(tag_name, user_id)
+ try:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+
+ tag_id = tag.id
+ if tag_id not in chat.meta.get("tags", []):
+ chat.meta = {
+ **chat.meta,
+ "tags": chat.meta.get("tags", []) + [tag_id],
+ }
+
+ db.commit()
+ db.refresh(chat)
+ return ChatModel.model_validate(chat)
+ except Exception:
+ return None
+
+ def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
+ with get_db() as db: # Assuming `get_db()` returns a session object
+ query = db.query(Chat).filter_by(user_id=user_id)
+
+ # Normalize the tag_name for consistency
+ tag_id = tag_name.replace(" ", "_").lower()
+
+ if db.bind.dialect.name == "sqlite":
+ # SQLite JSON1 support for querying the tags inside the `meta` JSON field
+ query = query.filter(
+ text(
+ f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
+ )
+ ).params(tag_id=tag_id)
+
+ elif db.bind.dialect.name == "postgresql":
+ # PostgreSQL JSON support for querying the tags inside the `meta` JSON field
+ query = query.filter(
+ text(
+ "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
+ )
+ ).params(tag_id=tag_id)
+
+ else:
+ raise NotImplementedError(
+ f"Unsupported dialect: {db.bind.dialect.name}"
+ )
+
+ # Get the count of matching records
+ count = query.count()
+
+ # Debugging output for inspection
+ print(f"Count of chats for tag '{tag_name}':", count)
+
+ return count
+
+ def delete_tag_by_id_and_user_id_and_tag_name(
+ self, id: str, user_id: str, tag_name: str
+ ) -> bool:
+ try:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+ tags = chat.meta.get("tags", [])
+ tag_id = tag_name.replace(" ", "_").lower()
+
+ tags = [tag for tag in tags if tag != tag_id]
+ chat.meta = {
+ **chat.meta,
+ "tags": tags,
+ }
+ db.commit()
+ return True
+ except Exception:
+ return False
+
+ def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+ try:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+ chat.meta = {
+ **chat.meta,
+ "tags": [],
+ }
+ db.commit()
+
+ return True
+ except Exception:
+ return False
+
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:
diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/apps/webui/models/files.py
index f8d4cf8e8..20e0ffe6d 100644
--- a/backend/open_webui/apps/webui/models/files.py
+++ b/backend/open_webui/apps/webui/models/files.py
@@ -50,6 +50,14 @@ class FileModel(BaseModel):
####################
+class FileMeta(BaseModel):
+ name: Optional[str] = None
+ content_type: Optional[str] = None
+ size: Optional[int] = None
+
+ model_config = ConfigDict(extra="allow")
+
+
class FileModelResponse(BaseModel):
id: str
user_id: str
@@ -57,12 +65,19 @@ class FileModelResponse(BaseModel):
filename: str
data: Optional[dict] = None
- meta: dict
+ meta: FileMeta
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
+class FileMetadataResponse(BaseModel):
+ id: str
+ meta: dict
+ created_at: int # timestamp in epoch
+ updated_at: int # timestamp in epoch
+
+
class FileForm(BaseModel):
id: str
hash: Optional[str] = None
@@ -104,6 +119,19 @@ class FilesTable:
except Exception:
return None
+ def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
+ with get_db() as db:
+ try:
+ file = db.get(File, id)
+ return FileMetadataResponse(
+ id=file.id,
+ meta=file.meta,
+ created_at=file.created_at,
+ updated_at=file.updated_at,
+ )
+ except Exception:
+ return None
+
def get_files(self) -> list[FileModel]:
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
@@ -118,6 +146,21 @@ class FilesTable:
.all()
]
+ def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
+ with get_db() as db:
+ return [
+ FileMetadataResponse(
+ id=file.id,
+ meta=file.meta,
+ created_at=file.created_at,
+ updated_at=file.updated_at,
+ )
+ for file in db.query(File)
+ .filter(File.id.in_(ids))
+ .order_by(File.updated_at.desc())
+ .all()
+ ]
+
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db:
return [
diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py
index 698cccda0..2423d1f84 100644
--- a/backend/open_webui/apps/webui/models/knowledge.py
+++ b/backend/open_webui/apps/webui/models/knowledge.py
@@ -6,6 +6,10 @@ import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.apps.webui.models.files import FileMetadataResponse
+
+
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel):
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
+ files: Optional[list[FileMetadataResponse | dict]] = None
+
class KnowledgeForm(BaseModel):
name: str
diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/apps/webui/models/tags.py
index 985273ff1..ef209b565 100644
--- a/backend/open_webui/apps/webui/models/tags.py
+++ b/backend/open_webui/apps/webui/models/tags.py
@@ -4,53 +4,32 @@ import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
+
+
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy import BigInteger, Column, String, JSON
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
####################
# Tag DB Schema
####################
-
-
class Tag(Base):
__tablename__ = "tag"
-
id = Column(String, primary_key=True)
name = Column(String)
user_id = Column(String)
- data = Column(Text, nullable=True)
-
-
-class ChatIdTag(Base):
- __tablename__ = "chatidtag"
-
- id = Column(String, primary_key=True)
- tag_name = Column(String)
- chat_id = Column(String)
- user_id = Column(String)
- timestamp = Column(BigInteger)
+ meta = Column(JSON, nullable=True)
class TagModel(BaseModel):
id: str
name: str
user_id: str
- data: Optional[str] = None
-
- model_config = ConfigDict(from_attributes=True)
-
-
-class ChatIdTagModel(BaseModel):
- id: str
- tag_name: str
- chat_id: str
- user_id: str
- timestamp: int
-
+ meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
@@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel):
####################
-class ChatIdTagForm(BaseModel):
- tag_name: str
+class TagChatIdForm(BaseModel):
+ name: str
chat_id: str
-class TagChatIdsResponse(BaseModel):
- chat_ids: list[str]
-
-
-class ChatTagsResponse(BaseModel):
- tags: list[str]
-
-
class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
- id = str(uuid.uuid4())
+ id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
@@ -93,170 +64,38 @@ class TagTable:
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
+ id = name.replace(" ", "_").lower()
with get_db() as db:
- tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
+ tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
- def add_tag_to_chat(
- self, user_id: str, form_data: ChatIdTagForm
- ) -> Optional[ChatIdTagModel]:
- tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
- if tag is None:
- tag = self.insert_new_tag(form_data.tag_name, user_id)
-
- id = str(uuid.uuid4())
- chatIdTag = ChatIdTagModel(
- **{
- "id": id,
- "user_id": user_id,
- "chat_id": form_data.chat_id,
- "tag_name": tag.name,
- "timestamp": int(time.time()),
- }
- )
- try:
- with get_db() as db:
- result = ChatIdTag(**chatIdTag.model_dump())
- db.add(result)
- db.commit()
- db.refresh(result)
- if result:
- return ChatIdTagModel.model_validate(result)
- else:
- return None
- except Exception:
- return None
-
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db:
- tag_names = [
- chat_id_tag.tag_name
- for chat_id_tag in (
- db.query(ChatIdTag)
- .filter_by(user_id=user_id)
- .order_by(ChatIdTag.timestamp.desc())
- .all()
- )
- ]
-
return [
TagModel.model_validate(tag)
- for tag in (
- db.query(Tag)
- .filter_by(user_id=user_id)
- .filter(Tag.name.in_(tag_names))
- .all()
- )
+ for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
- def get_tags_by_chat_id_and_user_id(
- self, chat_id: str, user_id: str
- ) -> list[TagModel]:
+ def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
with get_db() as db:
- tag_names = [
- chat_id_tag.tag_name
- for chat_id_tag in (
- db.query(ChatIdTag)
- .filter_by(user_id=user_id, chat_id=chat_id)
- .order_by(ChatIdTag.timestamp.desc())
- .all()
- )
- ]
-
return [
TagModel.model_validate(tag)
- for tag in (
- db.query(Tag)
- .filter_by(user_id=user_id)
- .filter(Tag.name.in_(tag_names))
- .all()
- )
+ for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
]
- def get_chat_ids_by_tag_name_and_user_id(
- self, tag_name: str, user_id: str
- ) -> list[ChatIdTagModel]:
- with get_db() as db:
- return [
- ChatIdTagModel.model_validate(chat_id_tag)
- for chat_id_tag in (
- db.query(ChatIdTag)
- .filter_by(user_id=user_id, tag_name=tag_name)
- .order_by(ChatIdTag.timestamp.desc())
- .all()
- )
- ]
-
- def count_chat_ids_by_tag_name_and_user_id(
- self, tag_name: str, user_id: str
- ) -> int:
- with get_db() as db:
- return (
- db.query(ChatIdTag)
- .filter_by(tag_name=tag_name, user_id=user_id)
- .count()
- )
-
- def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
+ def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try:
with get_db() as db:
- res = (
- db.query(ChatIdTag)
- .filter_by(tag_name=tag_name, user_id=user_id)
- .delete()
- )
+ id = name.replace(" ", "_").lower()
+ res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
db.commit()
-
- tag_count = self.count_chat_ids_by_tag_name_and_user_id(
- tag_name, user_id
- )
- if tag_count == 0:
- # Remove tag item from Tag col as well
- db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
- db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
- def delete_tag_by_tag_name_and_chat_id_and_user_id(
- self, tag_name: str, chat_id: str, user_id: str
- ) -> bool:
- try:
- with get_db() as db:
- res = (
- db.query(ChatIdTag)
- .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
- .delete()
- )
- log.debug(f"res: {res}")
- db.commit()
-
- tag_count = self.count_chat_ids_by_tag_name_and_user_id(
- tag_name, user_id
- )
- if tag_count == 0:
- # Remove tag item from Tag col as well
- db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
- db.commit()
-
- return True
- except Exception as e:
- log.error(f"delete_tag: {e}")
- return False
-
- def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
- tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
-
- for tag in tags:
- self.delete_tag_by_tag_name_and_chat_id_and_user_id(
- tag.tag_name, chat_id, user_id
- )
-
- return True
-
Tags = TagTable()
diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py
index 563fc145f..e9f94ff6a 100644
--- a/backend/open_webui/apps/webui/routers/auths.py
+++ b/backend/open_webui/apps/webui/routers/auths.py
@@ -18,6 +18,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
+ WEBUI_SESSION_COOKIE_SAME_SITE,
+ WEBUI_SESSION_COOKIE_SECURE,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
@@ -27,6 +29,7 @@ from open_webui.utils.utils import (
create_api_key,
create_token,
get_admin_user,
+ get_verified_user,
get_current_user,
get_password_hash,
)
@@ -53,6 +56,8 @@ async def get_session_user(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
@@ -71,7 +76,7 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
- form_data: UpdateProfileForm, session_user=Depends(get_current_user)
+ form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
):
if session_user:
user = Users.update_user_by_id(
@@ -166,6 +171,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
@@ -236,6 +243,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
if request.app.state.config.WEBHOOK_URL:
diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py
index ca7e95baf..6a9c26f8c 100644
--- a/backend/open_webui/apps/webui/routers/chats.py
+++ b/backend/open_webui/apps/webui/routers/chats.py
@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
Chats,
ChatTitleIdResponse,
)
-from open_webui.apps.webui.models.tags import (
- ChatIdTagForm,
- ChatIdTagModel,
- TagModel,
- Tags,
-)
+from open_webui.apps.webui.models.tags import TagModel, Tags
+
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@@ -95,7 +91,7 @@ async def get_user_chat_list_by_user_id(
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
@@ -108,10 +104,46 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################
+@router.get("/search", response_model=list[ChatTitleIdResponse])
+async def search_user_chats(
+ text: str, page: Optional[int] = None, user=Depends(get_verified_user)
+):
+ if page is None:
+ page = 1
+
+ limit = 60
+ skip = (page - 1) * limit
+
+ return [
+ ChatTitleIdResponse(**chat.model_dump())
+ for chat in Chats.get_chats_by_user_id_and_search_text(
+ user.id, text, skip=skip, limit=limit
+ )
+ ]
+
+
+############################
+# GetPinnedChats
+############################
+
+
+@router.get("/pinned", response_model=list[ChatResponse])
+async def get_user_pinned_chats(user=Depends(get_verified_user)):
+ return [
+ ChatResponse(**chat.model_dump())
+ for chat in Chats.get_pinned_chats_by_user_id(user.id)
+ ]
+
+
+############################
+# GetChats
+############################
+
+
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
return [
- ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ ChatResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id(user.id)
]
@@ -124,11 +156,28 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)):
return [
- ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ ChatResponse(**chat.model_dump())
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
+############################
+# GetAllTags
+############################
+
+
+@router.get("/all/tags", response_model=list[TagModel])
+async def get_all_user_tags(user=Depends(get_verified_user)):
+ try:
+ tags = Tags.get_tags_by_user_id(user.id)
+ return tags
+ except Exception as e:
+ log.exception(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+ )
+
+
############################
# GetAllChatsInDB
############################
@@ -141,10 +190,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
- return [
- ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
- for chat in Chats.get_chats()
- ]
+ return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
############################
@@ -187,7 +233,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id(share_id)
if chat:
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
+
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -199,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################
-class TagNameForm(BaseModel):
+class TagForm(BaseModel):
name: str
+
+
+class TagFilterForm(TagForm):
skip: Optional[int] = 0
limit: Optional[int] = 50
@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
- form_data: TagNameForm, user=Depends(get_verified_user)
+ form_data: TagFilterForm, user=Depends(get_verified_user)
):
- chat_ids = [
- chat_id_tag.chat_id
- for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
- form_data.name, user.id
- )
- ]
-
- chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
-
+ chats = Chats.get_chat_list_by_user_id_and_tag_name(
+ user.id, form_data.name, form_data.skip, form_data.limit
+ )
if len(chats) == 0:
- Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+ Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats
-############################
-# GetAllTags
-############################
-
-
-@router.get("/tags/all", response_model=list[TagModel])
-async def get_all_tags(user=Depends(get_verified_user)):
- try:
- tags = Tags.get_tags_by_user_id(user.id)
- return tags
- except Exception as e:
- log.exception(e)
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
- )
-
-
############################
# GetChatById
############################
@@ -251,7 +278,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
+
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -269,10 +297,9 @@ async def update_chat_by_id(
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
- updated_chat = {**json.loads(chat.chat), **form_data.chat}
-
+ updated_chat = {**chat.chat, **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -303,25 +330,57 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result
+############################
+# GetPinnedStatusById
+############################
+
+
+@router.get("/{id}/pinned", response_model=Optional[bool])
+async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ return chat.pinned
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+ )
+
+
+############################
+# PinChatById
+############################
+
+
+@router.post("/{id}/pin", response_model=Optional[ChatResponse])
+async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ chat = Chats.toggle_chat_pinned_by_id(id)
+ return chat
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+ )
+
+
############################
# CloneChat
############################
-@router.get("/{id}/clone", response_model=Optional[ChatResponse])
+@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
- chat_body = json.loads(chat.chat)
updated_chat = {
- **chat_body,
+ **chat.chat,
"originalChatId": chat.id,
- "branchPointMessageId": chat_body["history"]["currentId"],
+ "branchPointMessageId": chat.chat["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -333,12 +392,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################
-@router.get("/{id}/archive", response_model=Optional[ChatResponse])
+@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
- return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+ return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -356,9 +415,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
- return ChatResponse(
- **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
- )
+ return ChatResponse(**shared_chat.model_dump())
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
@@ -366,10 +423,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
+ return ChatResponse(**shared_chat.model_dump())
- return ChatResponse(
- **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
- )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -407,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
- tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
-
- if tags != None:
- return tags
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ tags = chat.meta.get("tags", [])
+ return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -422,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################
-@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
-async def add_chat_tag_by_id(
- id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
+@router.post("/{id}/tags", response_model=list[TagModel])
+async def add_tag_by_id_and_tag_name(
+ id: str, form_data: TagForm, user=Depends(get_verified_user)
):
- tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ tags = chat.meta.get("tags", [])
+ tag_id = form_data.name.replace(" ", "_").lower()
- if form_data.tag_name not in tags:
- tag = Tags.add_tag_to_chat(user.id, form_data)
-
- if tag:
- return tag
- else:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail=ERROR_MESSAGES.NOT_FOUND,
+ print(tags, tag_id)
+ if tag_id not in tags:
+ Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
+ id, user.id, form_data.name
)
+
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ tags = chat.meta.get("tags", [])
+ return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -449,16 +506,20 @@ async def add_chat_tag_by_id(
############################
-@router.delete("/{id}/tags", response_model=Optional[bool])
-async def delete_chat_tag_by_id(
- id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
+@router.delete("/{id}/tags", response_model=list[TagModel])
+async def delete_tag_by_id_and_tag_name(
+ id: str, form_data: TagForm, user=Depends(get_verified_user)
):
- result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
- form_data.tag_name, id, user.id
- )
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
- if result:
- return result
+ if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
+ Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
+
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ tags = chat.meta.get("tags", [])
+ return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -472,10 +533,17 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
- result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ Chats.delete_all_tags_by_id_and_user_id(id, user.id)
- if result:
- return result
+ for tag in chat.meta.get("tags", []):
+ if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+ Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ tags = chat.meta.get("tags", [])
+ return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py
index 0679ae062..8185971d1 100644
--- a/backend/open_webui/apps/webui/routers/files.py
+++ b/backend/open_webui/apps/webui/routers/files.py
@@ -213,7 +213,7 @@ async def update_file_data_content_by_id(
############################
-@router.get("/{id}/content", response_model=Optional[FileModel])
+@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
@@ -223,7 +223,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
- return FileResponse(file_path)
+ headers = {
+ "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+ }
+ return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -236,7 +239,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
-@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
+@router.get("/{id}/content/{file_name}")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
@@ -248,7 +251,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
- return FileResponse(file_path)
+ headers = {
+ "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
+ }
+ return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py
index a792c24fa..9cb38a821 100644
--- a/backend/open_webui/apps/webui/routers/knowledge.py
+++ b/backend/open_webui/apps/webui/routers/knowledge.py
@@ -48,7 +48,12 @@ async def get_knowledge_items(
)
else:
return [
- KnowledgeResponse(**knowledge.model_dump())
+ KnowledgeResponse(
+ **knowledge.model_dump(),
+ files=Files.get_file_metadatas_by_ids(
+ knowledge.data.get("file_ids", []) if knowledge.data else []
+ ),
+ )
for knowledge in Knowledges.get_knowledge_items()
]
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index bfc9a4ded..98d342897 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -901,6 +901,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
+# Qdrant
+QDRANT_URI = os.environ.get("QDRANT_URI", None)
+
####################################
# Information Retrieval (RAG)
####################################
@@ -986,10 +989,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
- "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
- "rag.embedding_openai_batch_size",
- int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+ "RAG_EMBEDDING_BATCH_SIZE",
+ "rag.embedding_batch_size",
+ int(
+ os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
+ or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
+ ),
)
RAG_RERANKING_MODEL = PersistentConfig(
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index fbf22d84d..0f2ecada0 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
+####################################
+# REDIS
+####################################
+
+REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
+
####################################
# WEBUI_AUTH (Required for security)
####################################
@@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
-WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
-
+WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@@ -355,3 +360,9 @@ else:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
+
+####################################
+# OFFLINE_MODE
+####################################
+
+OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 7086a3cc9..5b819d78b 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -102,6 +102,7 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
+ OFFLINE_MODE,
)
from fastapi import (
Depends,
@@ -178,14 +179,14 @@ class SPAStaticFiles(StaticFiles):
print(
rf"""
- ___ __ __ _ _ _ ___
+ ___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
-| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
-| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
+| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
+| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
- |_|
+ |_|
+
-
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
@@ -824,6 +825,32 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware)
+from urllib.parse import urlencode, parse_qs, urlparse
+
+
+class RedirectMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request: Request, call_next):
+ # Check if the request is a GET request
+ if request.method == "GET":
+ path = request.url.path
+ query_params = dict(parse_qs(urlparse(str(request.url)).query))
+
+ # Check for the specific watch path and the presence of 'v' parameter
+ if path.endswith("/watch") and "v" in query_params:
+ video_id = query_params["v"][0] # Extract the first 'v' parameter
+ encoded_video_id = urlencode({"youtube": video_id})
+ redirect_url = f"/?{encoded_video_id}"
+ return RedirectResponse(url=redirect_url)
+
+ # Proceed with the normal flow of other requests
+ response = await call_next(request)
+ return response
+
+
+# Add the middleware to the app
+app.add_middleware(RedirectMiddleware)
+
+
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@@ -2181,6 +2208,11 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
+ if OFFLINE_MODE:
+ log.debug(
+ f"Offline mode is enabled, returning current version as latest version"
+ )
+ return {"current": VERSION, "latest": VERSION}
try:
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@@ -2353,6 +2385,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
@@ -2416,6 +2450,7 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
+
if os.path.exists(FRONTEND_BUILD_DIR):
mimetypes.add_type("text/javascript", ".js")
app.mount(
diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py
new file mode 100644
index 000000000..9d79b5749
--- /dev/null
+++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py
@@ -0,0 +1,151 @@
+"""Migrate tags
+
+Revision ID: 1af9b942657b
+Revises: 242a2047eae0
+Create Date: 2024-10-09 21:02:35.241684
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, select, update, column
+from sqlalchemy.engine.reflection import Inspector
+
+import json
+
+revision = "1af9b942657b"
+down_revision = "242a2047eae0"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # Set up an inspector on the existing connection to avoid issues
+ conn = op.get_bind()
+ inspector = Inspector.from_engine(conn)
+
+ # Clean up potential leftover temp table from previous failures
+ conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
+
+ # Check if the 'tag' table exists
+ tables = inspector.get_table_names()
+
+ # Step 1: Modify Tag table using batch mode for SQLite support
+ if "tag" in tables:
+ # Get the current columns in the 'tag' table
+ columns = [col["name"] for col in inspector.get_columns("tag")]
+
+ # Get any existing unique constraints on the 'tag' table
+ current_constraints = inspector.get_unique_constraints("tag")
+
+ with op.batch_alter_table("tag", schema=None) as batch_op:
+ # Check if the unique constraint already exists
+ if not any(
+ constraint["name"] == "uq_id_user_id"
+ for constraint in current_constraints
+ ):
+ # Create unique constraint if it doesn't exist
+ batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
+
+ # Check if the 'data' column exists before trying to drop it
+ if "data" in columns:
+ batch_op.drop_column("data")
+
+ # Check if the 'meta' column needs to be created
+ if "meta" not in columns:
+ # Add the 'meta' column if it doesn't already exist
+ batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
+
+ tag = table(
+ "tag",
+ column("id", sa.String()),
+ column("name", sa.String()),
+ column("user_id", sa.String()),
+ column("meta", sa.JSON()),
+ )
+
+ # Step 2: Migrate tags
+ conn = op.get_bind()
+ result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
+
+ tag_updates = {}
+ for row in result:
+ new_id = row.name.replace(" ", "_").lower()
+ tag_updates[row.id] = new_id
+
+ for tag_id, new_tag_id in tag_updates.items():
+ print(f"Updating tag {tag_id} to {new_tag_id}")
+ if new_tag_id == "pinned":
+ # delete tag
+ delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
+ conn.execute(delete_stmt)
+ else:
+ # Check if the new_tag_id already exists in the database
+ existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id)
+ existing_tag_result = conn.execute(existing_tag_query).fetchone()
+
+ if existing_tag_result:
+ # Handle duplicate case: the new_tag_id already exists
+ print(
+ f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
+ )
+ # Option 1: Delete the current tag if an update to new_tag_id would cause duplication
+ delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
+ conn.execute(delete_stmt)
+ else:
+ update_stmt = sa.update(tag).where(tag.c.id == tag_id)
+ update_stmt = update_stmt.values(id=new_tag_id)
+ conn.execute(update_stmt)
+
+ # Add columns `pinned` and `meta` to 'chat'
+ op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
+ op.add_column(
+ "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
+ )
+
+ chatidtag = table(
+ "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
+ )
+ chat = table(
+ "chat",
+ column("id", sa.String()),
+ column("pinned", sa.Boolean()),
+ column("meta", sa.JSON()),
+ )
+
+ # Fetch existing tags
+ conn = op.get_bind()
+ result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
+
+ chat_updates = {}
+ for row in result:
+ chat_id = row.chat_id
+ tag_name = row.tag_name.replace(" ", "_").lower()
+
+ if tag_name == "pinned":
+ # Specifically handle 'pinned' tag
+ if chat_id not in chat_updates:
+ chat_updates[chat_id] = {"pinned": True, "meta": {}}
+ else:
+ chat_updates[chat_id]["pinned"] = True
+ else:
+ if chat_id not in chat_updates:
+ chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
+ else:
+ tags = chat_updates[chat_id]["meta"].get("tags", [])
+ tags.append(tag_name)
+
+ chat_updates[chat_id]["meta"]["tags"] = tags
+
+ # Update chats based on accumulated changes
+ for chat_id, updates in chat_updates.items():
+ update_stmt = sa.update(chat).where(chat.c.id == chat_id)
+ update_stmt = update_stmt.values(
+ meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
+ )
+ conn.execute(update_stmt)
+ pass
+
+
+def downgrade():
+ pass
diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py
new file mode 100644
index 000000000..596703dc2
--- /dev/null
+++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py
@@ -0,0 +1,82 @@
+"""Update chat table
+
+Revision ID: 242a2047eae0
+Revises: 6a39f3d8e55c
+Create Date: 2024-10-09 21:02:35.241684
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, select, update
+
+import json
+
+revision = "242a2047eae0"
+down_revision = "6a39f3d8e55c"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # Step 1: Rename current 'chat' column to 'old_chat'
+ op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text)
+
+ # Step 2: Add new 'chat' column of type JSON
+ op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
+
+ # Step 3: Migrate data from 'old_chat' to 'chat'
+ chat_table = table(
+ "chat",
+ sa.Column("id", sa.String, primary_key=True),
+ sa.Column("old_chat", sa.Text),
+ sa.Column("chat", sa.JSON()),
+ )
+
+    # - Select all rows from the table so each can be converted
+ connection = op.get_bind()
+ results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
+ for row in results:
+ try:
+ # Convert text JSON to actual JSON object, assuming the text is in JSON format
+ json_data = json.loads(row.old_chat)
+ except json.JSONDecodeError:
+ json_data = None # Handle cases where the text cannot be converted to JSON
+
+ connection.execute(
+ sa.update(chat_table)
+ .where(chat_table.c.id == row.id)
+ .values(chat=json_data)
+ )
+
+ # Step 4: Drop 'old_chat' column
+ op.drop_column("chat", "old_chat")
+
+
+def downgrade():
+ # Step 1: Add 'old_chat' column back as Text
+ op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
+
+ # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
+ chat_table = table(
+ "chat",
+ sa.Column("id", sa.String, primary_key=True),
+ sa.Column("chat", sa.JSON()),
+ sa.Column("old_chat", sa.Text()),
+ )
+
+ connection = op.get_bind()
+ results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
+ for row in results:
+ text_data = json.dumps(row.chat) if row.chat is not None else None
+ connection.execute(
+ sa.update(chat_table)
+ .where(chat_table.c.id == row.id)
+ .values(old_chat=text_data)
+ )
+
+ # Step 3: Remove the new 'chat' JSON column
+ op.drop_column("chat", "chat")
+
+ # Step 4: Rename 'old_chat' back to 'chat'
+ op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)
diff --git a/backend/open_webui/test/apps/webui/routers/test_documents.py b/backend/open_webui/test/apps/webui/routers/test_documents.py
deleted file mode 100644
index 4d30b35e4..000000000
--- a/backend/open_webui/test/apps/webui/routers/test_documents.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from test.util.abstract_integration_test import AbstractPostgresTest
-from test.util.mock_user import mock_webui_user
-
-
-class TestDocuments(AbstractPostgresTest):
- BASE_PATH = "/api/v1/documents"
-
- def setup_class(cls):
- super().setup_class()
- from open_webui.apps.webui.models.documents import Documents
-
- cls.documents = Documents
-
- def test_documents(self):
- # Empty database
- assert len(self.documents.get_docs()) == 0
- with mock_webui_user(id="2"):
- response = self.fast_api_client.get(self.create_url("/"))
- assert response.status_code == 200
- assert len(response.json()) == 0
-
- # Create a new document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.post(
- self.create_url("/create"),
- json={
- "name": "doc_name",
- "title": "doc title",
- "collection_name": "custom collection",
- "filename": "doc_name.pdf",
- "content": "",
- },
- )
- assert response.status_code == 200
- assert response.json()["name"] == "doc_name"
- assert len(self.documents.get_docs()) == 1
-
- # Get the document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
- assert response.status_code == 200
- data = response.json()
- assert data["collection_name"] == "custom collection"
- assert data["name"] == "doc_name"
- assert data["title"] == "doc title"
- assert data["filename"] == "doc_name.pdf"
- assert data["content"] == {}
-
- # Create another document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.post(
- self.create_url("/create"),
- json={
- "name": "doc_name 2",
- "title": "doc title 2",
- "collection_name": "custom collection 2",
- "filename": "doc_name2.pdf",
- "content": "",
- },
- )
- assert response.status_code == 200
- assert response.json()["name"] == "doc_name 2"
- assert len(self.documents.get_docs()) == 2
-
- # Get all documents
- with mock_webui_user(id="2"):
- response = self.fast_api_client.get(self.create_url("/"))
- assert response.status_code == 200
- assert len(response.json()) == 2
-
- # Update the first document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.post(
- self.create_url("/doc/update?name=doc_name"),
- json={"name": "doc_name rework", "title": "updated title"},
- )
- assert response.status_code == 200
- data = response.json()
- assert data["name"] == "doc_name rework"
- assert data["title"] == "updated title"
-
- # Tag the first document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.post(
- self.create_url("/doc/tags"),
- json={
- "name": "doc_name rework",
- "tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
- },
- )
- assert response.status_code == 200
- data = response.json()
- assert data["name"] == "doc_name rework"
- assert data["content"] == {
- "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
- }
- assert len(self.documents.get_docs()) == 2
-
- # Delete the first document
- with mock_webui_user(id="2"):
- response = self.fast_api_client.delete(
- self.create_url("/doc/delete?name=doc_name rework")
- )
- assert response.status_code == 200
- assert len(self.documents.get_docs()) == 1
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 80b4d541f..24c3fdaef 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -12,6 +12,7 @@ passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.10.8
+async-timeout
sqlalchemy==2.0.32
alembic==1.13.2
@@ -41,6 +42,7 @@ langchain-chroma==0.1.4
fake-useragent==1.5.1
chromadb==0.5.9
pymilvus==2.4.7
+qdrant-client~=1.12.0
sentence-transformers==3.0.1
colbert-ai==0.2.21
diff --git a/cypress/e2e/documents.cy.ts b/cypress/e2e/documents.cy.ts
index 6ca14980d..b14b1de20 100644
--- a/cypress/e2e/documents.cy.ts
+++ b/cypress/e2e/documents.cy.ts
@@ -1,46 +1,2 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
///