diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 0e5672013..6398b9ee1 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -450,7 +450,7 @@ def transcribe(file_path): except Exception: error_detail = f"External: {e}" - raise error_detail + raise Exception(error_detail) @app.post("/transcriptions") diff --git a/backend/open_webui/apps/images/utils/comfyui.py b/backend/open_webui/apps/images/utils/comfyui.py index 0a3e3a1d9..4c421d7c5 100644 --- a/backend/open_webui/apps/images/utils/comfyui.py +++ b/backend/open_webui/apps/images/utils/comfyui.py @@ -125,22 +125,34 @@ async def comfyui_generate_image( workflow[node_id]["inputs"][node.key] = model elif node.type == "prompt": for node_id in node.node_ids: - workflow[node_id]["inputs"]["text"] = payload.prompt + workflow[node_id]["inputs"][ + node.key if node.key else "text" + ] = payload.prompt elif node.type == "negative_prompt": for node_id in node.node_ids: - workflow[node_id]["inputs"]["text"] = payload.negative_prompt + workflow[node_id]["inputs"][ + node.key if node.key else "text" + ] = payload.negative_prompt elif node.type == "width": for node_id in node.node_ids: - workflow[node_id]["inputs"]["width"] = payload.width + workflow[node_id]["inputs"][ + node.key if node.key else "width" + ] = payload.width elif node.type == "height": for node_id in node.node_ids: - workflow[node_id]["inputs"]["height"] = payload.height + workflow[node_id]["inputs"][ + node.key if node.key else "height" + ] = payload.height elif node.type == "n": for node_id in node.node_ids: - workflow[node_id]["inputs"]["batch_size"] = payload.n + workflow[node_id]["inputs"][ + node.key if node.key else "batch_size" + ] = payload.n elif node.type == "steps": for node_id in node.node_ids: - workflow[node_id]["inputs"]["steps"] = payload.steps + workflow[node_id]["inputs"][ + node.key if node.key else "steps" + ] = payload.steps elif node.type == "seed": seed = ( payload.seed diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 33d984655..f835e3175 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel): class GenerateEmbedForm(BaseModel): model: str - input: str - truncate: Optional[bool] + input: list[str]|str + truncate: Optional[bool] = None options: Optional[dict] = None keep_alive: Optional[Union[int, str]] = None @@ -560,48 +560,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embed", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return generate_ollama_batch_embeddings(form_data, url_idx) @app.post("/api/embeddings") @@ -611,48 +570,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - if url_idx is None: - model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - log.info(f"url: {url}") - - r = requests.request( - method="POST", - url=f"{url}/api/embeddings", - headers={"Content-Type": "application/json"}, - data=form_data.model_dump_json(exclude_none=True).encode(), - ) - try: - r.raise_for_status() - - return r.json() - except Exception as e: - log.exception(e) - error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except Exception: - error_detail = f"Ollama: {e}" - - raise HTTPException( - status_code=r.status_code if r else 500, - detail=error_detail, - ) + return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx) def generate_ollama_embeddings( @@ -692,7 +610,64 @@ def generate_ollama_embeddings( log.info(f"generate_ollama_embeddings {data}") if "embedding" in data: - return data["embedding"] + return data + else: + raise Exception("Something went wrong :/") + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except Exception: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +def generate_ollama_batch_embeddings( + form_data: GenerateEmbedForm, + url_idx: Optional[int] = None, +): + log.info(f"generate_ollama_batch_embeddings {form_data}") + + if url_idx is None: + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.config.OLLAMA_BASE_URLS[url_idx] + log.info(f"url: {url}") + + r = requests.request( + method="POST", + url=f"{url}/api/embed", + headers={"Content-Type": "application/json"}, + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + try: + r.raise_for_status() + + data = r.json() + + log.info(f"generate_ollama_batch_embeddings {data}") + + if "embeddings" in data: + return data else: raise Exception("Something went wrong :/") except Exception as e: @@ -788,8 +763,7 @@ async def generate_chat_completion( user=Depends(get_verified_user), ): payload = {**form_data.model_dump(exclude_none=True)} - log.debug(f"{payload = }") - + log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: del payload["metadata"] @@ -824,7 +798,7 @@ async def generate_chat_completion( url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - log.debug(payload) + log.debug(f"generate_chat_completion() - 2.payload = {payload}") return await post_streaming_url( f"{url}/api/chat", diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 52cebeabc..c80b2011d 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -63,7 +63,7 @@ from open_webui.config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_OPENAI_BATCH_SIZE, + RAG_EMBEDDING_BATCH_SIZE, RAG_FILE_MAX_COUNT, RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, @@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_TEMPLATE = RAG_TEMPLATE @@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) app.add_middleware( @@ -267,7 +267,7 @@ async def get_status(): "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, } @@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } @@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)): class OpenAIConfigForm(BaseModel): url: str key: str - batch_size: Optional[int] = None class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None embedding_engine: str embedding_model: str + embedding_batch_size: Optional[int] = 1 @app.post("/embedding/update") @@ -320,11 +320,7 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = ( - form_data.openai_config.batch_size - if form_data.openai_config.batch_size - else 1 - ) + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -334,17 +330,17 @@ async def update_embedding_config( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } except Exception as e: @@ -645,7 +641,7 @@ def save_docs_to_vector_db( filter={"hash": metadata["hash"]}, ) - if result: + if result is not None: existing_doc_ids = result.ids[0] if existing_doc_ids: log.info(f"Document with hash {metadata['hash']} already exists") @@ -690,7 +686,7 @@ def save_docs_to_vector_db( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( @@ -767,7 +763,7 @@ def process_file( collection_name=f"file-{file.id}", filter={"file_id": file.id} ) - if len(result.ids[0]) > 0: + if result is not None and len(result.ids[0]) > 0: docs = [ Document( page_content=result.documents[0][idx], diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 53961be2c..4ca2db1bd 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -12,8 +12,8 @@ from langchain_core.documents import Document from open_webui.apps.ollama.main import ( - GenerateEmbeddingsForm, - generate_ollama_embeddings, + GenerateEmbedForm, + generate_ollama_batch_embeddings, ) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message @@ -193,7 +193,8 @@ def query_collection( k=k, query_embedding=query_embedding, ) - results.append(result.model_dump()) + if result is not None: + results.append(result.model_dump()) except Exception as e: log.exception(f"Error when querying the collection: {e}") else: @@ -265,39 +266,27 @@ def get_embedding_function( embedding_function, openai_key, openai_url, - batch_size, + embedding_batch_size, ): if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - if embedding_engine == "ollama": - func = lambda query: generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": embedding_model, - "prompt": query, - } - ) - ) - elif embedding_engine == "openai": - func = lambda query: generate_openai_embeddings( - model=embedding_model, - text=query, - key=openai_key, - url=openai_url, - ) + func = lambda query: generate_embeddings( + engine=embedding_engine, + model=embedding_model, + text=query, + key=openai_key if embedding_engine == "openai" else "", + url=openai_url if embedding_engine == "openai" else "", + ) - def generate_multiple(query, f): + def generate_multiple(query, func): if isinstance(query, list): - if embedding_engine == "openai": - embeddings = [] - for i in range(0, len(query), batch_size): - embeddings.extend(f(query[i : i + batch_size])) - return embeddings - else: - return [f(q) for q in query] + embeddings = [] + for i in range(0, len(query), embedding_batch_size): + embeddings.extend(func(query[i : i + embedding_batch_size])) + return embeddings else: - return f(query) + return func(query) return lambda query: generate_multiple(query, func) @@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False): return model -def generate_openai_embeddings( - model: str, - text: Union[str, list[str]], - key: str, - url: str = "https://api.openai.com/v1", -): - if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, key, url) - else: - embeddings = generate_openai_batch_embeddings(model, [text], key, url) - - return embeddings[0] if isinstance(text, str) else embeddings - - def generate_openai_batch_embeddings( model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" ) -> Optional[list[list[float]]]: @@ -482,6 +457,33 @@ def generate_openai_batch_embeddings( return None +def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + if engine == "ollama": + if isinstance(text, list): + embeddings = generate_ollama_batch_embeddings( + GenerateEmbedForm(**{"model": model, "input": text}) + ) + else: + embeddings = generate_ollama_batch_embeddings( + GenerateEmbedForm(**{"model": model, "input": [text]}) + ) + return ( + embeddings["embeddings"][0] + if isinstance(text, str) + else embeddings["embeddings"] + ) + elif engine == "openai": + key = kwargs.get("key", "") + url = kwargs.get("url", "https://api.openai.com/v1") + + if isinstance(text, list): + embeddings = generate_openai_batch_embeddings(model, text, key, url) + else: + embeddings = generate_openai_batch_embeddings(model, [text], key, url) + + return embeddings[0] if isinstance(text, str) else embeddings + + import operator from typing import Optional, Sequence diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index 1f33b1721..c7f00f5fd 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -4,6 +4,10 @@ if VECTOR_DB == "milvus": from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient VECTOR_DB_CLIENT = MilvusClient() +elif VECTOR_DB == "qdrant": + from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient + + VECTOR_DB_CLIENT = QdrantClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py new file mode 100644 index 000000000..70908dc33 --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -0,0 +1,176 @@ +from typing import Optional + +from qdrant_client import QdrantClient as Qclient +from qdrant_client.http.models import PointStruct +from qdrant_client.models import models + +from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import QDRANT_URI + +NO_LIMIT = 999999999 + +class QdrantClient: + def __init__(self): + self.collection_prefix = "open-webui" + self.QDRANT_URI = QDRANT_URI + self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None + + def _result_to_get_result(self, points) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for point in points: + payload = point.payload + ids.append(point.id) + documents.append(payload["text"]) + metadatas.append(payload["metadata"]) + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + + def _create_collection(self, collection_name: str, dimension: int): + collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" + self.client.create_collection( + collection_name=collection_name_with_prefix, + vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), + ) + + print(f"collection {collection_name_with_prefix} successfully created!") + + def _create_collection_if_not_exists(self, collection_name, dimension): + if not self.has_collection( + collection_name=collection_name + ): + self._create_collection( + collection_name=collection_name, dimension=dimension + ) + + def _create_points(self, items: list[VectorItem]): + return [ + PointStruct( + id=item["id"], + vector=item["vector"], + payload={ + "text": item["text"], + "metadata": item["metadata"] + }, + ) + for item in items + ] + + def has_collection(self, collection_name: str) -> bool: + return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}") + + def delete_collection(self, collection_name: str): + return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}") + + def search( + self, collection_name: str, vectors: list[list[float | int]], limit: int + ) -> Optional[SearchResult]: + # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + if limit is None: + limit = NO_LIMIT # otherwise qdrant would set limit to 10! + + query_response = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query=vectors[0], + limit=limit, + ) + get_result = self._result_to_get_result(query_response.points) + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=[[point.score for point in query_response.points]] + ) + + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + # Construct the filter string for querying + if not self.has_collection(collection_name): + return None + try: + if limit is None: + limit = NO_LIMIT # otherwise qdrant would set limit to 10! + + field_conditions = [] + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) + + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + query_filter=models.Filter(should=field_conditions), + limit=limit, + ) + return self._result_to_get_result(points.points) + except Exception as e: + print(e) + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + # Get all the items in the collection. + points = self.client.query_points( + collection_name=f"{self.collection_prefix}_{collection_name}", + limit=NO_LIMIT # otherwise qdrant would set limit to 10! + ) + return self._result_to_get_result(points.points) + + def insert(self, collection_name: str, items: list[VectorItem]): + # Insert the items into the collection, if the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self._create_points(items) + self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points) + + def upsert(self, collection_name: str, items: list[VectorItem]): + # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + self._create_collection_if_not_exists(collection_name, len(items[0]["vector"])) + points = self._create_points(items) + return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) + + def delete( + self, + collection_name: str, + ids: Optional[list[str]] = None, + filter: Optional[dict] = None, + ): + # Delete the items from the collection based on the ids. + field_conditions = [] + + if ids: + for id_value in ids: + field_conditions.append( + models.FieldCondition( + key="metadata.id", + match=models.MatchValue(value=id_value), + ), + ), + elif filter: + for key, value in filter.items(): + field_conditions.append( + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ), + + return self.client.delete( + collection_name=f"{self.collection_prefix}_{collection_name}", + points_selector=models.FilterSelector( + filter=models.Filter( + must=field_conditions + ) + ), + ) + + def reset(self): + # Resets the database. This will delete all collections and item entries. + collection_names = self.client.get_collections().collections + for collection_name in collection_names: + if collection_name.name.startswith(self.collection_prefix): + self.client.delete_collection(collection_name=collection_name.name) diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index f364dcc70..04355e997 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -4,8 +4,13 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.tags import TagModel, Tag, Tags + + from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import or_, func, select, and_, text +from sqlalchemy.sql import exists #################### # Chat DB Schema @@ -18,13 +23,16 @@ class Chat(Base): id = Column(String, primary_key=True) user_id = Column(String) title = Column(Text) - chat = Column(Text) # Save Chat JSON as Text + chat = Column(JSON) created_at = Column(BigInteger) updated_at = Column(BigInteger) share_id = Column(Text, unique=True, nullable=True) archived = Column(Boolean, default=False) + pinned = Column(Boolean, default=False, nullable=True) + + meta = Column(JSON, server_default="{}") class ChatModel(BaseModel): @@ -33,13 +41,16 @@ class ChatModel(BaseModel): id: str user_id: str title: str - chat: str + chat: dict created_at: int # timestamp in epoch updated_at: int # timestamp in epoch share_id: Optional[str] = None archived: bool = False + pinned: Optional[bool] = False + + meta: dict = {} #################### @@ -64,6 +75,8 @@ class ChatResponse(BaseModel): created_at: int # timestamp in epoch share_id: Optional[str] = None # id of the chat to be shared archived: bool + pinned: Optional[bool] = False + meta: dict = {} class ChatTitleIdResponse(BaseModel): @@ -86,7 +99,7 @@ class ChatTable: if "title" in form_data.chat else "New Chat" ), - "chat": json.dumps(form_data.chat), + "chat": form_data.chat, "created_at": int(time.time()), "updated_at": int(time.time()), } @@ -101,14 +114,14 @@ class ChatTable: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: with get_db() as db: - chat_obj = db.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) + chat_item = db.get(Chat, id) + chat_item.chat = chat + chat_item.title = chat["title"] if "title" in chat else "New Chat" + chat_item.updated_at = int(time.time()) db.commit() - db.refresh(chat_obj) + db.refresh(chat_item) - return ChatModel.model_validate(chat_obj) + return ChatModel.model_validate(chat_item) except Exception: return None @@ -182,11 +195,24 @@ class ChatTable: except Exception: return None + def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.pinned = not chat.pinned + chat.updated_at = int(time.time()) + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: chat = db.get(Chat, id) chat.archived = not chat.archived + chat.updated_at = int(time.time()) db.commit() db.refresh(chat) return ChatModel.model_validate(chat) @@ -249,10 +275,10 @@ class ChatTable: Chat.id, Chat.title, Chat.updated_at, Chat.created_at ) - if limit: - query = query.limit(limit) if skip: query = query.offset(skip) + if limit: + query = query.limit(limit) all_chats = query.all() @@ -328,6 +354,15 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + with get_db() as db: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, pinned=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] + def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -337,6 +372,207 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chats_by_user_id_and_search_text( + self, + user_id: str, + search_text: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 60, + ) -> list[ChatModel]: + """ + Filters chats based on a search query using Python, allowing pagination using skip and limit. + """ + search_text = search_text.lower().strip() + if not search_text: + return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) + + with get_db() as db: + query = db.query(Chat).filter(Chat.user_id == user_id) + + if not include_archived: + query = query.filter(Chat.archived == False) + + query = query.order_by(Chat.updated_at.desc()) + + # Check if the database dialect is either 'sqlite' or 'postgresql' + dialect_name = db.bind.dialect.name + if dialect_name == "sqlite": + # SQLite case: using JSON1 extension for JSON searching + query = query.filter( + ( + Chat.title.ilike( + f"%{search_text}%" + ) # Case-insensitive search in title + | text( + """ + EXISTS ( + SELECT 1 + FROM json_each(Chat.chat, '$.messages') AS message + WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%' + ) + """ + ) + ).params(search_text=search_text) + ) + elif dialect_name == "postgresql": + # PostgreSQL relies on proper JSON query for search + query = query.filter( + ( + Chat.title.ilike( + f"%{search_text}%" + ) # Case-insensitive search in title + | text( + """ + EXISTS ( + SELECT 1 + FROM json_array_elements(Chat.chat->'messages') AS message + WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%' + ) + """ + ) + ).params(search_text=search_text) + ) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + # Perform pagination at the SQL level + all_chats = query.offset(skip).limit(limit).all() + + # Validate and return chats + return [ChatModel.model_validate(chat) for chat in all_chats] + + def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] + + def get_chat_list_by_user_id_and_tag_name( + self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 + ) -> list[ChatModel]: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + tag_id = tag_name.replace(" ", "_").lower() + + print(db.bind.dialect.name) + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 querying for tags within the meta JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSON query for tags within the meta JSON field (for `json` type) + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + all_chats = query.all() + print("all_chats", all_chats) + return [ChatModel.model_validate(chat) for chat in all_chats] + + def add_chat_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> Optional[ChatModel]: + tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) + if tag is None: + tag = Tags.insert_new_tag(tag_name, user_id) + try: + with get_db() as db: + chat = db.get(Chat, id) + + tag_id = tag.id + if tag_id not in chat.meta.get("tags", []): + chat.meta = { + **chat.meta, + "tags": chat.meta.get("tags", []) + [tag_id], + } + + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + + def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: + with get_db() as db: # Assuming `get_db()` returns a session object + query = db.query(Chat).filter_by(user_id=user_id) + + # Normalize the tag_name for consistency + tag_id = tag_name.replace(" ", "_").lower() + + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSONB support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + # Get the count of matching records + count = query.count() + + # Debugging output for inspection + print(f"Count of chats for tag '{tag_name}':", count) + + return count + + def delete_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + tag_id = tag_name.replace(" ", "_").lower() + + tags = [tag for tag in tags if tag != tag_id] + chat.meta = { + **chat.meta, + "tags": tags, + } + db.commit() + return True + except Exception: + return False + + def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.meta = { + **chat.meta, + "tags": [], + } + db.commit() + + return True + except Exception: + return False + def delete_chat_by_id(self, id: str) -> bool: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/models/files.py b/backend/open_webui/apps/webui/models/files.py index f8d4cf8e8..20e0ffe6d 100644 --- a/backend/open_webui/apps/webui/models/files.py +++ b/backend/open_webui/apps/webui/models/files.py @@ -50,6 +50,14 @@ class FileModel(BaseModel): #################### +class FileMeta(BaseModel): + name: Optional[str] = None + content_type: Optional[str] = None + size: Optional[int] = None + + model_config = ConfigDict(extra="allow") + + class FileModelResponse(BaseModel): id: str user_id: str @@ -57,12 +65,19 @@ class FileModelResponse(BaseModel): filename: str data: Optional[dict] = None - meta: dict + meta: FileMeta created_at: int # timestamp in epoch updated_at: int # timestamp in epoch +class FileMetadataResponse(BaseModel): + id: str + meta: dict + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + class FileForm(BaseModel): id: str hash: Optional[str] = None @@ -104,6 +119,19 @@ class FilesTable: except Exception: return None + def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: + with get_db() as db: + try: + file = db.get(File, id) + return FileMetadataResponse( + id=file.id, + meta=file.meta, + created_at=file.created_at, + updated_at=file.updated_at, + ) + except Exception: + return None + def get_files(self) -> list[FileModel]: with get_db() as db: return [FileModel.model_validate(file) for file in db.query(File).all()] @@ -118,6 +146,21 @@ class FilesTable: .all() ] + def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: + with get_db() as db: + return [ + FileMetadataResponse( + id=file.id, + meta=file.meta, + created_at=file.created_at, + updated_at=file.updated_at, + ) + for file in db.query(File) + .filter(File.id.in_(ids)) + .order_by(File.updated_at.desc()) + .all() + ] + def get_files_by_user_id(self, user_id: str) -> list[FileModel]: with get_db() as db: return [ diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 698cccda0..2423d1f84 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -6,6 +6,10 @@ import uuid from open_webui.apps.webui.internal.db import Base, get_db from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.files import FileMetadataResponse + + from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel): created_at: int # timestamp in epoch updated_at: int # timestamp in epoch + files: Optional[list[FileMetadataResponse | dict]] = None + class KnowledgeForm(BaseModel): name: str diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/apps/webui/models/tags.py index 985273ff1..ef209b565 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/apps/webui/models/tags.py @@ -4,53 +4,32 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db + + from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, JSON log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) + #################### # Tag DB Schema #################### - - class Tag(Base): __tablename__ = "tag" - id = Column(String, primary_key=True) name = Column(String) user_id = Column(String) - data = Column(Text, nullable=True) - - -class ChatIdTag(Base): - __tablename__ = "chatidtag" - - id = Column(String, primary_key=True) - tag_name = Column(String) - chat_id = Column(String) - user_id = Column(String) - timestamp = Column(BigInteger) + meta = Column(JSON, nullable=True) class TagModel(BaseModel): id: str name: str user_id: str - data: Optional[str] = None - - model_config = ConfigDict(from_attributes=True) - - -class ChatIdTagModel(BaseModel): - id: str - tag_name: str - chat_id: str - user_id: str - timestamp: int - + meta: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel): #################### -class ChatIdTagForm(BaseModel): - tag_name: str +class TagChatIdForm(BaseModel): + name: str chat_id: str -class TagChatIdsResponse(BaseModel): - chat_ids: list[str] - - -class ChatTagsResponse(BaseModel): - tags: list[str] - - class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: with get_db() as db: - id = str(uuid.uuid4()) + id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: result = Tag(**tag.model_dump()) @@ -93,170 +64,38 @@ class TagTable: self, name: str, user_id: str ) -> Optional[TagModel]: try: + id = name.replace(" ", "_").lower() with get_db() as db: - tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() + tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def add_tag_to_chat( - self, user_id: str, form_data: ChatIdTagForm - ) -> Optional[ChatIdTagModel]: - tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) - if tag is None: - tag = self.insert_new_tag(form_data.tag_name, user_id) - - id = str(uuid.uuid4()) - chatIdTag = ChatIdTagModel( - **{ - "id": id, - "user_id": user_id, - "chat_id": form_data.chat_id, - "tag_name": tag.name, - "timestamp": int(time.time()), - } - ) - try: - with get_db() as db: - result = ChatIdTag(**chatIdTag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None - except Exception: - return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) + for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] - def get_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str - ) -> list[TagModel]: + def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) + for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all()) ] - def get_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> list[ChatIdTagModel]: - with get_db() as db: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - - def count_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> int: - with get_db() as db: - return ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) - - def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: + def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: try: with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) + id = name.replace(" ", "_").lower() + res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() return True except Exception as e: log.error(f"delete_tag: {e}") return False - def delete_tag_by_tag_name_and_chat_id_and_user_id( - self, tag_name: str, chat_id: str, user_id: str - ) -> bool: - try: - with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() - - return True - except Exception as e: - log.error(f"delete_tag: {e}") - return False - - def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: - tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) - - for tag in tags: - self.delete_tag_by_tag_name_and_chat_id_and_user_id( - tag.tag_name, chat_id, user_id - ) - - return True - Tags = TagTable() diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index 563fc145f..e9f94ff6a 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -18,6 +18,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, ) from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import Response @@ -27,6 +29,7 @@ from open_webui.utils.utils import ( create_api_key, create_token, get_admin_user, + get_verified_user, get_current_user, get_password_hash, ) @@ -53,6 +56,8 @@ async def get_session_user( key="token", value=token, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) return { @@ -71,7 +76,7 @@ async def get_session_user( @router.post("/update/profile", response_model=UserResponse) async def update_profile( - form_data: UpdateProfileForm, session_user=Depends(get_current_user) + form_data: UpdateProfileForm, session_user=Depends(get_verified_user) ): if session_user: user = Users.update_user_by_id( @@ -166,6 +171,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm): key="token", value=token, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) return { @@ -236,6 +243,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm): key="token", value=token, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) if request.app.state.config.WEBHOOK_URL: diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index ca7e95baf..6a9c26f8c 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import ( Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import ( - ChatIdTagForm, - ChatIdTagModel, - TagModel, - Tags, -) +from open_webui.apps.webui.models.tags import TagModel, Tags + from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS @@ -95,7 +91,7 @@ async def get_user_chat_list_by_user_id( async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): try: chat = Chats.insert_new_chat(user.id, form_data) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) raise HTTPException( @@ -108,10 +104,46 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): ############################ +@router.get("/search", response_model=list[ChatTitleIdResponse]) +async def search_user_chats( + text: str, page: Optional[int] = None, user=Depends(get_verified_user) +): + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + return [ + ChatTitleIdResponse(**chat.model_dump()) + for chat in Chats.get_chats_by_user_id_and_search_text( + user.id, text, skip=skip, limit=limit + ) + ] + + +############################ +# GetPinnedChats +############################ + + +@router.get("/pinned", response_model=list[ChatResponse]) +async def get_user_pinned_chats(user=Depends(get_verified_user)): + return [ + ChatResponse(**chat.model_dump()) + for chat in Chats.get_pinned_chats_by_user_id(user.id) + ] + + +############################ +# GetChats +############################ + + @router.get("/all", response_model=list[ChatResponse]) async def get_user_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id(user.id) ] @@ -124,11 +156,28 @@ async def get_user_chats(user=Depends(get_verified_user)): @router.get("/all/archived", response_model=list[ChatResponse]) async def get_user_archived_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_archived_chats_by_user_id(user.id) ] +############################ +# GetAllTags +############################ + + +@router.get("/all/tags", response_model=list[TagModel]) +async def get_all_user_tags(user=Depends(get_verified_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetAllChatsInDB ############################ @@ -141,10 +190,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats() - ] + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] ############################ @@ -187,7 +233,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id(share_id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -199,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): ############################ -class TagNameForm(BaseModel): +class TagForm(BaseModel): name: str + + +class TagFilterForm(TagForm): skip: Optional[int] = 0 limit: Optional[int] = 50 @router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_verified_user) + form_data: TagFilterForm, user=Depends(get_verified_user) ): - chat_ids = [ - chat_id_tag.chat_id - for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - form_data.name, user.id - ) - ] - - chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) - + chats = Chats.get_chat_list_by_user_id_and_tag_name( + user.id, form_data.name, form_data.skip, form_data.limit + ) if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) return chats -############################ -# GetAllTags -############################ - - -@router.get("/tags/all", response_model=list[TagModel]) -async def get_all_tags(user=Depends(get_verified_user)): - try: - tags = Tags.get_tags_by_user_id(user.id) - return tags - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) - - ############################ # GetChatById ############################ @@ -251,7 +278,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -269,10 +297,9 @@ async def update_chat_by_id( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - updated_chat = {**json.loads(chat.chat), **form_data.chat} - + updated_chat = {**chat.chat, **form_data.chat} chat = Chats.update_chat_by_id(id, updated_chat) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -303,25 +330,57 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified return result +############################ +# GetPinnedStatusById +############################ + + +@router.get("/{id}/pinned", response_model=Optional[bool]) +async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + return chat.pinned + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# PinChatById +############################ + + +@router.post("/{id}/pin", response_model=Optional[ChatResponse]) +async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + chat = Chats.toggle_chat_pinned_by_id(id) + return chat + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # CloneChat ############################ -@router.get("/{id}/clone", response_model=Optional[ChatResponse]) +@router.post("/{id}/clone", response_model=Optional[ChatResponse]) async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat_body = json.loads(chat.chat) updated_chat = { - **chat_body, + **chat.chat, "originalChatId": chat.id, - "branchPointMessageId": chat_body["history"]["currentId"], + "branchPointMessageId": chat.chat["history"]["currentId"], "title": f"Clone of {chat.title}", } chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -333,12 +392,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.get("/{id}/archive", response_model=Optional[ChatResponse]) +@router.post("/{id}/archive", response_model=Optional[ChatResponse]) async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -356,9 +415,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): if chat: if chat.share_id: shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) + return ChatResponse(**shared_chat.model_dump()) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) if not shared_chat: @@ -366,10 +423,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=ERROR_MESSAGES.DEFAULT(), ) + return ChatResponse(**shared_chat.model_dump()) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -407,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/tags", response_model=list[TagModel]) async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) - - if tags != None: - return tags + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -422,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) -async def add_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.post("/{id}/tags", response_model=list[TagModel]) +async def add_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + tag_id = form_data.name.replace(" ", "_").lower() - if form_data.tag_name not in tags: - tag = Tags.add_tag_to_chat(user.id, form_data) - - if tag: - return tag - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, + print(tags, tag_id) + if tag_id not in tags: + Chats.add_chat_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name ) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -449,16 +506,20 @@ async def add_chat_tag_by_id( ############################ -@router.delete("/{id}/tags", response_model=Optional[bool]) -async def delete_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.delete("/{id}/tags", response_model=list[TagModel]) +async def delete_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( - form_data.tag_name, id, user.id - ) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) - if result: - return result + if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -472,10 +533,17 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_all_tags_by_id_and_user_id(id, user.id) - if result: - return result + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index 0679ae062..8185971d1 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -213,7 +213,7 @@ async def update_file_data_content_by_id( ############################ -@router.get("/{id}/content", response_model=Optional[FileModel]) +@router.get("/{id}/content") async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) @@ -223,7 +223,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): print(f"file_path: {file_path}") - return FileResponse(file_path) + headers = { + "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + } + return FileResponse(file_path, headers=headers) else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -236,7 +239,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ) -@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) +@router.get("/{id}/content/{file_name}") async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id) @@ -248,7 +251,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): print(f"file_path: {file_path}") - return FileResponse(file_path) + headers = { + "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"' + } + return FileResponse(file_path, headers=headers) else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index a792c24fa..9cb38a821 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -48,7 +48,12 @@ async def get_knowledge_items( ) else: return [ - KnowledgeResponse(**knowledge.model_dump()) + KnowledgeResponse( + **knowledge.model_dump(), + files=Files.get_file_metadatas_by_ids( + knowledge.data.get("file_ids", []) if knowledge.data else [] + ), + ) for knowledge in Knowledges.get_knowledge_items() ] diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index bfc9a4ded..98d342897 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -901,6 +901,9 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") +# Qdrant +QDRANT_URI = os.environ.get("QDRANT_URI", None) + #################################### # Information Retrieval (RAG) #################################### @@ -986,10 +989,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) -RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig( - "RAG_EMBEDDING_OPENAI_BATCH_SIZE", - "rag.embedding_openai_batch_size", - int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")), +RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( + "RAG_EMBEDDING_BATCH_SIZE", + "rag.embedding_batch_size", + int( + os.environ.get("RAG_EMBEDDING_BATCH_SIZE") + or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1") + ), ) RAG_RERANKING_MODEL = PersistentConfig( diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index fbf22d84d..0f2ecada0 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) +#################################### +# REDIS +#################################### + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + #################################### # WEBUI_AUTH (Required for security) #################################### @@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = ( WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0") - +WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") @@ -355,3 +360,9 @@ else: AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) except Exception: AIOHTTP_CLIENT_TIMEOUT = 300 + +#################################### +# OFFLINE_MODE +#################################### + +OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7086a3cc9..5b819d78b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -102,6 +102,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, WEBUI_URL, RESET_CONFIG_ON_START, + OFFLINE_MODE, ) from fastapi import ( Depends, @@ -178,14 +179,14 @@ class SPAStaticFiles(StaticFiles): print( rf""" - ___ __ __ _ _ _ ___ + ___ __ __ _ _ _ ___ / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| -| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | -| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | +| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | +| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| - |_| + |_| + - v{VERSION} - building the best open-source AI user interface. {f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} https://github.com/open-webui/open-webui @@ -824,6 +825,32 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) +from urllib.parse import urlencode, parse_qs, urlparse + + +class RedirectMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Check if the request is a GET request + if request.method == "GET": + path = request.url.path + query_params = dict(parse_qs(urlparse(str(request.url)).query)) + + # Check for the specific watch path and the presence of 'v' parameter + if path.endswith("/watch") and "v" in query_params: + video_id = query_params["v"][0] # Extract the first 'v' parameter + encoded_video_id = urlencode({"youtube": video_id}) + redirect_url = f"/?{encoded_video_id}" + return RedirectResponse(url=redirect_url) + + # Proceed with the normal flow of other requests + response = await call_next(request) + return response + + +# Add the middleware to the app +app.add_middleware(RedirectMiddleware) + + app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -2181,6 +2208,11 @@ async def get_app_changelog(): @app.get("/api/version/updates") async def get_app_latest_release_version(): + if OFFLINE_MODE: + log.debug( + f"Offline mode is enabled, returning current version as latest version" + ) + return {"current": VERSION, "latest": VERSION} try: timeout = aiohttp.ClientTimeout(total=1) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: @@ -2353,6 +2385,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): key="token", value=jwt_token, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, ) # Redirect back to the frontend with the JWT token @@ -2416,6 +2450,7 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") + if os.path.exists(FRONTEND_BUILD_DIR): mimetypes.add_type("text/javascript", ".js") app.mount( diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py new file mode 100644 index 000000000..9d79b5749 --- /dev/null +++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py @@ -0,0 +1,151 @@ +"""Migrate tags + +Revision ID: 1af9b942657b +Revises: 242a2047eae0 +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update, column +from sqlalchemy.engine.reflection import Inspector + +import json + +revision = "1af9b942657b" +down_revision = "242a2047eae0" +branch_labels = None +depends_on = None + + +def upgrade(): + # Setup an inspection on the existing table to avoid issues + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + + # Clean up potential leftover temp table from previous failures + conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag")) + + # Check if the 'tag' table exists + tables = inspector.get_table_names() + + # Step 1: Modify Tag table using batch mode for SQLite support + if "tag" in tables: + # Get the current columns in the 'tag' table + columns = [col["name"] for col in inspector.get_columns("tag")] + + # Get any existing unique constraints on the 'tag' table + current_constraints = inspector.get_unique_constraints("tag") + + with op.batch_alter_table("tag", schema=None) as batch_op: + # Check if the unique constraint already exists + if not any( + constraint["name"] == "uq_id_user_id" + for constraint in current_constraints + ): + # Create unique constraint if it doesn't exist + batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) + + # Check if the 'data' column exists before trying to drop it + if "data" in columns: + batch_op.drop_column("data") + + # Check if the 'meta' column needs to be created + if "meta" not in columns: + # Add the 'meta' column if it doesn't already exist + batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True)) + + tag = table( + "tag", + column("id", sa.String()), + column("name", sa.String()), + column("user_id", sa.String()), + column("meta", sa.JSON()), + ) + + # Step 2: Migrate tags + conn = op.get_bind() + result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id)) + + tag_updates = {} + for row in result: + new_id = row.name.replace(" ", "_").lower() + tag_updates[row.id] = new_id + + for tag_id, new_tag_id in tag_updates.items(): + print(f"Updating tag {tag_id} to {new_tag_id}") + if new_tag_id == "pinned": + # delete tag + delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) + conn.execute(delete_stmt) + else: + # Check if the new_tag_id already exists in the database + existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id) + existing_tag_result = conn.execute(existing_tag_query).fetchone() + + if existing_tag_result: + # Handle duplicate case: the new_tag_id already exists + print( + f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates." + ) + # Option 1: Delete the current tag if an update to new_tag_id would cause duplication + delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) + conn.execute(delete_stmt) + else: + update_stmt = sa.update(tag).where(tag.c.id == tag_id) + update_stmt = update_stmt.values(id=new_tag_id) + conn.execute(update_stmt) + + # Add columns `pinned` and `meta` to 'chat' + op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True)) + op.add_column( + "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}") + ) + + chatidtag = table( + "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String()) + ) + chat = table( + "chat", + column("id", sa.String()), + column("pinned", sa.Boolean()), + column("meta", sa.JSON()), + ) + + # Fetch existing tags + conn = op.get_bind() + result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name)) + + chat_updates = {} + for row in result: + chat_id = row.chat_id + tag_name = row.tag_name.replace(" ", "_").lower() + + if tag_name == "pinned": + # Specifically handle 'pinned' tag + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": True, "meta": {}} + else: + chat_updates[chat_id]["pinned"] = True + else: + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}} + else: + tags = chat_updates[chat_id]["meta"].get("tags", []) + tags.append(tag_name) + + chat_updates[chat_id]["meta"]["tags"] = tags + + # Update chats based on accumulated changes + for chat_id, updates in chat_updates.items(): + update_stmt = sa.update(chat).where(chat.c.id == chat_id) + update_stmt = update_stmt.values( + meta=updates.get("meta", {}), pinned=updates.get("pinned", False) + ) + conn.execute(update_stmt) + pass + + +def downgrade(): + pass diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py new file mode 100644 index 000000000..596703dc2 --- /dev/null +++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py @@ -0,0 +1,82 @@ +"""Update chat table + +Revision ID: 242a2047eae0 +Revises: 6a39f3d8e55c +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update + +import json + +revision = "242a2047eae0" +down_revision = "6a39f3d8e55c" +branch_labels = None +depends_on = None + + +def upgrade(): + # Step 1: Rename current 'chat' column to 'old_chat' + op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text) + + # Step 2: Add new 'chat' column of type JSON + op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) + + # Step 3: Migrate data from 'old_chat' to 'chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("old_chat", sa.Text), + sa.Column("chat", sa.JSON()), + ) + + # - Selecting all data from the table + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat)) + for row in results: + try: + # Convert text JSON to actual JSON object, assuming the text is in JSON format + json_data = json.loads(row.old_chat) + except json.JSONDecodeError: + json_data = None # Handle cases where the text cannot be converted to JSON + + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(chat=json_data) + ) + + # Step 4: Drop 'old_chat' column + op.drop_column("chat", "old_chat") + + +def downgrade(): + # Step 1: Add 'old_chat' column back as Text + op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) + + # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("chat", sa.JSON()), + sa.Column("old_chat", sa.Text()), + ) + + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) + for row in results: + text_data = json.dumps(row.chat) if row.chat is not None else None + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(old_chat=text_data) + ) + + # Step 3: Remove the new 'chat' JSON column + op.drop_column("chat", "chat") + + # Step 4: Rename 'old_chat' back to 'chat' + op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text) diff --git a/backend/open_webui/test/apps/webui/routers/test_documents.py b/backend/open_webui/test/apps/webui/routers/test_documents.py deleted file mode 100644 index 4d30b35e4..000000000 --- a/backend/open_webui/test/apps/webui/routers/test_documents.py +++ /dev/null @@ -1,105 +0,0 @@ -from test.util.abstract_integration_test import AbstractPostgresTest -from test.util.mock_user import mock_webui_user - - -class TestDocuments(AbstractPostgresTest): - BASE_PATH = "/api/v1/documents" - - def setup_class(cls): - super().setup_class() - from open_webui.apps.webui.models.documents import Documents - - cls.documents = Documents - - def test_documents(self): - # Empty database - assert len(self.documents.get_docs()) == 0 - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) - assert response.status_code == 200 - assert len(response.json()) == 0 - - # Create a new document - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url("/create"), - json={ - "name": "doc_name", - "title": "doc title", - "collection_name": "custom collection", - "filename": "doc_name.pdf", - "content": "", - }, - ) - assert response.status_code == 200 - assert response.json()["name"] == "doc_name" - assert len(self.documents.get_docs()) == 1 - - # Get the document - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/doc?name=doc_name")) - assert response.status_code == 200 - data = response.json() - assert data["collection_name"] == "custom collection" - assert data["name"] == "doc_name" - assert data["title"] == "doc title" - assert data["filename"] == "doc_name.pdf" - assert data["content"] == {} - - # Create another document - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url("/create"), - json={ - "name": "doc_name 2", - "title": "doc title 2", - "collection_name": "custom collection 2", - "filename": "doc_name2.pdf", - "content": "", - }, - ) - assert response.status_code == 200 - assert response.json()["name"] == "doc_name 2" - assert len(self.documents.get_docs()) == 2 - - # Get all documents - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) - assert response.status_code == 200 - assert len(response.json()) == 2 - - # Update the first document - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url("/doc/update?name=doc_name"), - json={"name": "doc_name rework", "title": "updated title"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["name"] == "doc_name rework" - assert data["title"] == "updated title" - - # Tag the first document - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url("/doc/tags"), - json={ - "name": "doc_name rework", - "tags": [{"name": "testing-tag"}, {"name": "another-tag"}], - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["name"] == "doc_name rework" - assert data["content"] == { - "tags": [{"name": "testing-tag"}, {"name": "another-tag"}] - } - assert len(self.documents.get_docs()) == 2 - - # Delete the first document - with mock_webui_user(id="2"): - response = self.fast_api_client.delete( - self.create_url("/doc/delete?name=doc_name rework") - ) - assert response.status_code == 200 - assert len(self.documents.get_docs()) == 1 diff --git a/backend/requirements.txt b/backend/requirements.txt index 80b4d541f..24c3fdaef 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,6 +12,7 @@ passlib[bcrypt]==1.7.4 requests==2.32.3 aiohttp==3.10.8 +async-timeout sqlalchemy==2.0.32 alembic==1.13.2 @@ -41,6 +42,7 @@ langchain-chroma==0.1.4 fake-useragent==1.5.1 chromadb==0.5.9 pymilvus==2.4.7 +qdrant-client~=1.12.0 sentence-transformers==3.0.1 colbert-ai==0.2.21 diff --git a/cypress/e2e/documents.cy.ts b/cypress/e2e/documents.cy.ts index 6ca14980d..b14b1de20 100644 --- a/cypress/e2e/documents.cy.ts +++ b/cypress/e2e/documents.cy.ts @@ -1,46 +1,2 @@ // eslint-disable-next-line @typescript-eslint/triple-slash-reference /// - -describe('Documents', () => { - const timestamp = Date.now(); - - before(() => { - cy.uploadTestDocument(timestamp); - }); - - after(() => { - cy.deleteTestDocument(timestamp); - }); - - context('Admin', () => { - beforeEach(() => { - // Login as the admin user - cy.loginAdmin(); - // Visit the home page - cy.visit('/workspace/documents'); - cy.get('button').contains('#cypress-test').click(); - }); - - it('can see documents', () => { - cy.get('div').contains(`document-test-initial-${timestamp}.txt`).should('have.length', 1); - }); - - it('can see edit button', () => { - cy.get('div') - .contains(`document-test-initial-${timestamp}.txt`) - .get("button[aria-label='Edit Doc']") - .should('exist'); - }); - - it('can see delete button', () => { - cy.get('div') - .contains(`document-test-initial-${timestamp}.txt`) - .get("button[aria-label='Delete Doc']") - .should('exist'); - }); - - it('can see upload button', () => { - cy.get("button[aria-label='Add Docs']").should('exist'); - }); - }); -}); diff --git a/cypress/support/e2e.ts b/cypress/support/e2e.ts index 984788733..0b94c4787 100644 --- a/cypress/support/e2e.ts +++ b/cypress/support/e2e.ts @@ -73,50 +73,6 @@ Cypress.Commands.add('register', (name, email, password) => register(name, email Cypress.Commands.add('registerAdmin', () => registerAdmin()); Cypress.Commands.add('loginAdmin', () => loginAdmin()); -Cypress.Commands.add('uploadTestDocument', (suffix: any) => { - // Login as admin - cy.loginAdmin(); - // upload example document - cy.visit('/workspace/documents'); - // Create a document - cy.get("button[aria-label='Add Docs']").click(); - cy.readFile('cypress/data/example-doc.txt').then((text) => { - // select file - cy.get('#upload-doc-input').selectFile( - { - contents: Cypress.Buffer.from(text + Date.now()), - fileName: `document-test-initial-${suffix}.txt`, - mimeType: 'text/plain', - lastModified: Date.now() - }, - { - force: true - } - ); - // open tag input - cy.get("button[aria-label='Add Tag']").click(); - cy.get("input[placeholder='Add a tag']").type('cypress-test'); - cy.get("button[aria-label='Save Tag']").click(); - - // submit to upload - cy.get("button[type='submit']").click(); - - // wait for upload to finish - cy.get('button').contains('#cypress-test').should('exist'); - cy.get('div').contains(`document-test-initial-${suffix}.txt`).should('exist'); - }); -}); - -Cypress.Commands.add('deleteTestDocument', (suffix: any) => { - cy.loginAdmin(); - cy.visit('/workspace/documents'); - // clean up uploaded documents - cy.get('div') - .contains(`document-test-initial-${suffix}.txt`) - .find("button[aria-label='Delete Doc']") - .click(); -}); - before(() => { cy.registerAdmin(); }); diff --git a/pyproject.toml b/pyproject.toml index f7a90a5b2..46c31e4a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "requests==2.32.3", "aiohttp==3.10.8", + "async-timeout", "sqlalchemy==2.0.32", "alembic==1.13.2", diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 8f4f81aea..ff89fdf43 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -167,6 +167,44 @@ export const getAllChats = async (token: string) => { return res; }; +export const getChatListBySearchText = async (token: string, text: string, page: number = 1) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('text', text); + searchParams.append('page', `${page}`); + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/search?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); +}; + export const getAllArchivedChats = async (token: string) => { let error = null; @@ -232,7 +270,7 @@ export const getAllUserChats = async (token: string) => { export const getAllChatTags = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, { method: 'GET', headers: { Accept: 'application/json', @@ -260,6 +298,40 @@ export const getAllChatTags = async (token: string) => { return res; }; +export const getPinnedChatList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); +}; + export const getChatListByTagName = async (token: string = '', tagName: string) => { let error = null; @@ -361,11 +433,87 @@ export const getChatByShareId = async (token: string, share_id: string) => { return res; }; +export const getChatPinnedStatusById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const toggleChatPinnedStatusById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const cloneChatById = async (token: string, id: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, { - method: 'GET', + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', @@ -435,7 +583,7 @@ export const archiveChatById = async (token: string, id: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, { - method: 'GET', + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', @@ -605,8 +753,7 @@ export const addTagById = async (token: string, id: string, tagName: string) => ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - tag_name: tagName, - chat_id: id + name: tagName }) }) .then(async (res) => { @@ -641,8 +788,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string) ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - tag_name: tagName, - chat_id: id + name: tagName }) }) .then(async (res) => { diff --git a/src/lib/apis/documents/index.ts b/src/lib/apis/documents/index.ts deleted file mode 100644 index 9d42feb19..000000000 --- a/src/lib/apis/documents/index.ts +++ /dev/null @@ -1,232 +0,0 @@ -import { WEBUI_API_BASE_URL } from '$lib/constants'; - -export const createNewDoc = async ( - token: string, - collection_name: string, - filename: string, - name: string, - title: string, - content: object | null = null -) => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/create`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - collection_name: collection_name, - filename: filename, - name: name, - title: title, - ...(content ? { content: JSON.stringify(content) } : {}) - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getDocs = async (token: string = '') => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const getDocByName = async (token: string, name: string) => { - let error = null; - - const searchParams = new URLSearchParams(); - searchParams.append('name', name); - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/docs?${searchParams.toString()}`, { - method: 'GET', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err.detail; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -type DocUpdateForm = { - name: string; - title: string; -}; - -export const updateDocByName = async (token: string, name: string, form: DocUpdateForm) => { - let error = null; - - const searchParams = new URLSearchParams(); - searchParams.append('name', name); - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/update?${searchParams.toString()}`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - name: form.name, - title: form.title - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err.detail; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -type TagDocForm = { - name: string; - tags: string[]; -}; - -export const tagDocByName = async (token: string, name: string, form: TagDocForm) => { - let error = null; - - const searchParams = new URLSearchParams(); - searchParams.append('name', name); - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/tags?${searchParams.toString()}`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - name: form.name, - tags: form.tags - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err.detail; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const deleteDocByName = async (token: string, name: string) => { - let error = null; - - const searchParams = new URLSearchParams(); - searchParams.append('name', name); - - const res = await fetch(`${WEBUI_API_BASE_URL}/documents/doc/delete?${searchParams.toString()}`, { - method: 'DELETE', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err.detail; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 9f49e9c0f..6c6b18b9f 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => { type OpenAIConfigForm = { key: string; url: string; - batch_size: number; }; type EmbeddingModelUpdateForm = { openai_config?: OpenAIConfigForm; embedding_engine: string; embedding_model: string; + embedding_batch_size?: number; }; export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { diff --git a/src/lib/components/AddFilesPlaceholder.svelte b/src/lib/components/AddFilesPlaceholder.svelte index a3057c560..d3d700795 100644 --- a/src/lib/components/AddFilesPlaceholder.svelte +++ b/src/lib/components/AddFilesPlaceholder.svelte @@ -2,20 +2,27 @@ import { getContext } from 'svelte'; export let title = ''; + export let content = ''; const i18n = getContext('i18n'); -
📄
-
- {#if title} - {title} - {:else} - {$i18n.t('Add Files')} - {/if} -
- -
- {$i18n.t('Drop any files here to add to the conversation')} +
+
📄
+
+ {#if title} + {title} + {:else} + {$i18n.t('Add Files')} + {/if}
- + +
+ {#if content} + {content} + {:else} + {$i18n.t('Drop any files here to add to the conversation')} + {/if} +
+
+
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index d6f7dc987..d94146c7d 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -38,6 +38,7 @@ let embeddingEngine = ''; let embeddingModel = ''; + let embeddingBatchSize = 1; let rerankingModel = ''; let fileMaxSize = null; @@ -53,7 +54,6 @@ let OpenAIKey = ''; let OpenAIUrl = ''; - let OpenAIBatchSize = 1; let querySettings = { template: '', @@ -100,12 +100,16 @@ const res = await updateEmbeddingConfig(localStorage.token, { embedding_engine: embeddingEngine, embedding_model: embeddingModel, + ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama' + ? { + embedding_batch_size: embeddingBatchSize + } + : {}), ...(embeddingEngine === 'openai' ? { openai_config: { key: OpenAIKey, - url: OpenAIUrl, - batch_size: OpenAIBatchSize + url: OpenAIUrl } } : {}) @@ -193,10 +197,10 @@ if (embeddingConfig) { embeddingEngine = embeddingConfig.embedding_engine; embeddingModel = embeddingConfig.embedding_model; + embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1; OpenAIKey = embeddingConfig.openai_config.key; OpenAIUrl = embeddingConfig.openai_config.url; - OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1; } }; @@ -309,6 +313,8 @@
+ {/if} + {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
{$i18n.t('Embedding Batch Size')}
@@ -318,13 +324,13 @@ min="1" max="2048" step="1" - bind:value={OpenAIBatchSize} + bind:value={embeddingBatchSize} class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700" />
{/if} +
+ +
+
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 54e14d984..83e3b967f 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -53,7 +53,7 @@ updateChatById } from '$lib/apis/chats'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; - import { processWebSearch } from '$lib/apis/retrieval'; + import { processWeb, processWebSearch, processYoutubeVideo } from '$lib/apis/retrieval'; import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; @@ -78,6 +78,7 @@ let loaded = false; const eventTarget = new EventTarget(); let controlPane; + let controlPaneComponent; let stopResponseFlag = false; let autoScroll = true; @@ -199,6 +200,20 @@ eventConfirmationTitle = data.title; eventConfirmationMessage = data.message; + } else if (type === 'execute') { + eventCallback = cb; + + try { + // Use Function constructor to evaluate code in a safer way + const asyncFunction = new Function(`return (async () => { ${data.code} })()`); + const result = await asyncFunction(); // Await the result of the async function + + if (cb) { + cb(result); + } + } catch (error) { + console.error('Error executing code:', error); + } } else if (type === 'input') { eventCallback = cb; @@ -276,14 +291,9 @@ if (controlPane && !$mobile) { try { if (value) { - const currentSize = controlPane.getSize(); - - if (currentSize === 0) { - const size = parseInt(localStorage?.chatControlsSize ?? '30'); - controlPane.resize(size ? size : 30); - } + controlPaneComponent.openPane(); } else { - controlPane.resize(0); + controlPane.collapse(); } } catch (e) { // ignore @@ -293,6 +303,7 @@ if (!value) { showCallOverlay.set(false); showOverview.set(false); + showArtifacts.set(false); } }); @@ -308,6 +319,74 @@ $socket?.off('chat-events'); }); + // File upload functions + + const uploadWeb = async (url) => { + console.log(url); + + const fileItem = { + type: 'doc', + name: url, + collection_name: '', + status: 'uploading', + url: url, + error: '' + }; + + try { + files = [...files, fileItem]; + const res = await processWeb(localStorage.token, '', url); + + if (res) { + fileItem.status = 'uploaded'; + fileItem.collection_name = res.collection_name; + fileItem.file = { + ...res.file, + ...fileItem.file + }; + + files = files; + } + } catch (e) { + // Remove the failed doc from the files array + files = files.filter((f) => f.name !== url); + toast.error(JSON.stringify(e)); + } + }; + + const uploadYoutubeTranscription = async (url) => { + console.log(url); + + const fileItem = { + type: 'doc', + name: url, + collection_name: '', + status: 'uploading', + context: 'full', + url: url, + error: '' + }; + + try { + files = [...files, fileItem]; + const res = await processYoutubeVideo(localStorage.token, url); + + if (res) { + fileItem.status = 'uploaded'; + fileItem.collection_name = res.collection_name; + fileItem.file = { + ...res.file, + ...fileItem.file + }; + files = files; + } + } catch (e) { + // Remove the failed doc from the files array + files = files.filter((f) => f.name !== url); + toast.error(e); + } + }; + ////////////////////////// // Web functions ////////////////////////// @@ -345,7 +424,17 @@ console.log($config?.default_models.split(',') ?? ''); selectedModels = $config?.default_models.split(','); } else { - selectedModels = ['']; + if ($models.length > 0) { + selectedModels = [$models[0].id]; + } else { + selectedModels = ['']; + } + } + + if ($page.url.searchParams.get('youtube')) { + uploadYoutubeTranscription( + `https://www.youtube.com/watch?v=${$page.url.searchParams.get('youtube')}` + ); } if ($page.url.searchParams.get('web-search') === 'true') { @@ -366,6 +455,11 @@ .filter((id) => id); } + if ($page.url.searchParams.get('call') === 'true') { + showCallOverlay.set(true); + showControls.set(true); + } + if ($page.url.searchParams.get('q')) { prompt = $page.url.searchParams.get('q') ?? ''; @@ -375,11 +469,6 @@ } } - if ($page.url.searchParams.get('call') === 'true') { - showCallOverlay.set(true); - showControls.set(true); - } - selectedModels = selectedModels.map((modelId) => $models.map((m) => m.id).includes(modelId) ? modelId : '' ); @@ -1855,6 +1944,7 @@ system: $settings.system ?? undefined, params: params, history: history, + messages: createMessagesList(history.currentId), tags: [], timestamp: Date.now() }); @@ -1920,6 +2010,7 @@ class="h-screen max-h-[100dvh] {$showSidebar ? 'md:max-w-[calc(100%-260px)]' : ''} w-full max-w-full flex flex-col" + id="chat-container" > {#if $settings?.backgroundImageUrl ?? null}
{ + const { type, data } = e.detail; + + if (type === 'web') { + await uploadWeb(data); + } else if (type === 'youtube') { + await uploadYoutubeTranscription(data); + } + }} on:submit={async (e) => { if (e.detail) { prompt = ''; @@ -2066,38 +2176,50 @@
{:else} - { - const model = $models.find((m) => m.id === e); - if (model?.info?.meta?.toolIds ?? false) { - return [...new Set([...a, ...model.info.meta.toolIds])]; - } - return a; - }, [])} - transparentBackground={$settings?.backgroundImageUrl ?? false} - {stopResponse} - {createMessagePair} - on:submit={async (e) => { - if (e.detail) { - prompt = ''; - await tick(); - submitPrompt(e.detail); - } - }} - /> +
+ { + const model = $models.find((m) => m.id === e); + if (model?.info?.meta?.toolIds ?? false) { + return [...new Set([...a, ...model.info.meta.toolIds])]; + } + return a; + }, [])} + transparentBackground={$settings?.backgroundImageUrl ?? false} + {stopResponse} + {createMessagePair} + on:upload={async (e) => { + const { type, data } = e.detail; + + if (type === 'web') { + await uploadWeb(data); + } else if (type === 'youtube') { + await uploadYoutubeTranscription(data); + } + }} + on:submit={async (e) => { + if (e.detail) { + prompt = ''; + await tick(); + submitPrompt(e.detail); + } + }} + /> +
{/if} import { SvelteFlowProvider } from '@xyflow/svelte'; import { slide } from 'svelte/transition'; + import { Pane, PaneResizer } from 'paneforge'; import { onDestroy, onMount, tick } from 'svelte'; import { mobile, showControls, showCallOverlay, showOverview, showArtifacts } from '$lib/stores'; @@ -10,9 +11,9 @@ import CallOverlay from './MessageInput/CallOverlay.svelte'; import Drawer from '../common/Drawer.svelte'; import Overview from './Overview.svelte'; - import { Pane, PaneResizer } from 'paneforge'; import EllipsisVertical from '../icons/EllipsisVertical.svelte'; import Artifacts from './Artifacts.svelte'; + import { min } from '@floating-ui/utils'; export let history; export let models = []; @@ -35,6 +36,16 @@ let largeScreen = false; let dragged = false; + let minSize = 0; + + export const openPane = () => { + if (parseInt(localStorage?.chatControlsSize)) { + pane.resize(parseInt(localStorage?.chatControlsSize)); + } else { + pane.resize(minSize); + } + }; + const handleMediaQuery = async (e) => { if (e.matches) { largeScreen = true; @@ -71,6 +82,32 @@ mediaQuery.addEventListener('change', handleMediaQuery); handleMediaQuery(mediaQuery); + // Select the container element you want to observe + const container = document.getElementById('chat-container'); + + // initialize the minSize based on the container width + minSize = Math.floor((350 / container.clientWidth) * 100); + + // Create a new ResizeObserver instance + const resizeObserver = new ResizeObserver((entries) => { + for (let entry of entries) { + const width = entry.contentRect.width; + // calculate the percentage of 200px + const percentage = (350 / width) * 100; + // set the minSize to the percentage, must be an integer + minSize = Math.floor(percentage); + + if ($showControls) { + if (pane && pane.isExpanded() && pane.getSize() < minSize) { + pane.resize(minSize); + } + } + } + }); + + // Start observing the container's size changes + resizeObserver.observe(container); + document.addEventListener('mousedown', onMouseDown); document.addEventListener('mouseup', onMouseUp); }); @@ -163,23 +200,29 @@ {/if} + { - if (size === 0) { - showControls.set(false); - } else { - if (!$showControls) { - showControls.set(true); + console.log('size', size, minSize); + + if ($showControls && pane.isExpanded()) { + if (size < minSize) { + pane.resize(minSize); + } + + if (size < minSize) { + localStorage.chatControlsSize = 0; + } else { + localStorage.chatControlsSize = size; } - localStorage.chatControlsSize = size; } }} + onCollapse={() => { + showControls.set(false); + }} + collapsible={true} class="pt-8" > {#if $showControls} @@ -187,7 +230,7 @@
{#if $showCallOverlay}
diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte index 25924535a..aba1e7374 100644 --- a/src/lib/components/chat/Controls/Controls.svelte +++ b/src/lib/components/chat/Controls/Controls.svelte @@ -16,7 +16,7 @@
-
+
{$i18n.t('Chat Controls')}
-
+
{#if chatFiles.length > 0}
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index b0991914f..dc4fcdf55 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,6 +1,6 @@ @@ -300,6 +302,9 @@ bind:this={commandsElement} bind:prompt bind:files + on:upload={(e) => { + dispatch('upload', e.detail); + }} on:select={(e) => { const data = e.detail; @@ -791,25 +796,27 @@ {/if} {:else}
- + + + + +
{/if}
diff --git a/src/lib/components/chat/MessageInput/Commands.svelte b/src/lib/components/chat/MessageInput/Commands.svelte index 9ce943364..183c17fed 100644 --- a/src/lib/components/chat/MessageInput/Commands.svelte +++ b/src/lib/components/chat/MessageInput/Commands.svelte @@ -26,71 +26,6 @@ let command = ''; $: command = (prompt?.trim() ?? '').split(' ')?.at(-1) ?? ''; - - const uploadWeb = async (url) => { - console.log(url); - - const fileItem = { - type: 'doc', - name: url, - collection_name: '', - status: 'uploading', - url: url, - error: '' - }; - - try { - files = [...files, fileItem]; - const res = await processWeb(localStorage.token, '', url); - - if (res) { - fileItem.status = 'uploaded'; - fileItem.collection_name = res.collection_name; - fileItem.file = { - ...res.file, - ...fileItem.file - }; - - files = files; - } - } catch (e) { - // Remove the failed doc from the files array - files = files.filter((f) => f.name !== url); - toast.error(JSON.stringify(e)); - } - }; - - const uploadYoutubeTranscription = async (url) => { - console.log(url); - - const fileItem = { - type: 'doc', - name: url, - collection_name: '', - status: 'uploading', - url: url, - error: '' - }; - - try { - files = [...files, fileItem]; - const res = await processYoutubeVideo(localStorage.token, url); - - if (res) { - fileItem.status = 'uploaded'; - fileItem.collection_name = res.collection_name; - fileItem.file = { - ...res.file, - ...fileItem.file - }; - files = files; - } - } catch (e) { - // Remove the failed doc from the files array - files = files.filter((f) => f.name !== url); - toast.error(e); - } - }; {#if ['/', '#', '@'].includes(command?.charAt(0))} @@ -103,18 +38,23 @@ {command} on:youtube={(e) => { console.log(e); - uploadYoutubeTranscription(e.detail); + dispatch('upload', { + type: 'youtube', + data: e.detail + }); }} on:url={(e) => { console.log(e); - uploadWeb(e.detail); + dispatch('upload', { + type: 'web', + data: e.detail + }); }} on:select={(e) => { console.log(e); files = [ ...files, { - type: e?.detail?.meta?.document ? 'file' : 'collection', ...e.detail, status: 'processed' } diff --git a/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte b/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte index edce2b7e9..031d93766 100644 --- a/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Knowledge.svelte @@ -2,6 +2,10 @@ import { toast } from 'svelte-sonner'; import Fuse from 'fuse.js'; + import dayjs from 'dayjs'; + import relativeTime from 'dayjs/plugin/relativeTime'; + dayjs.extend(relativeTime); + import { createEventDispatcher, tick, getContext, onMount } from 'svelte'; import { removeLastWordFromString, isValidHttpUrl } from '$lib/utils'; import { knowledge } from '$lib/stores'; @@ -72,7 +76,13 @@ }; onMount(() => { - let legacy_documents = $knowledge.filter((item) => item?.meta?.document); + let legacy_documents = $knowledge + .filter((item) => item?.meta?.document) + .map((item) => ({ + ...item, + type: 'file' + })); + let legacy_collections = legacy_documents.length > 0 ? [ @@ -101,12 +111,44 @@ ] : []; - items = [...$knowledge, ...legacy_collections].map((item) => { - return { + let collections = $knowledge + .filter((item) => !item?.meta?.document) + .map((item) => ({ ...item, - ...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {}) - }; - }); + type: 'collection' + })); + let collection_files = + $knowledge.length > 0 + ? [ + ...$knowledge + .reduce((a, item) => { + return [ + ...new Set([ + ...a, + ...(item?.files ?? []).map((file) => ({ + ...file, + collection: { name: item.name, description: item.description } + })) + ]) + ]; + }, []) + .map((file) => ({ + ...file, + name: file?.meta?.name, + description: `${file?.collection?.name} - ${file?.collection?.description}`, + type: 'file' + })) + ] + : []; + + items = [...collections, ...collection_files, ...legacy_collections, ...legacy_documents].map( + (item) => { + return { + ...item, + ...(item?.legacy || item?.meta?.legacy || item?.meta?.document ? { legacy: true } : {}) + }; + } + ); fuse = new Fuse(items, { keys: ['name', 'description'] @@ -117,20 +159,17 @@ {#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
-
-
-
#
-
- +
{#each filteredItems as item, idx} + + {/each} {#if prompt diff --git a/src/lib/components/chat/MessageInput/Commands/Models.svelte b/src/lib/components/chat/MessageInput/Commands/Models.svelte index 768c51421..d1da6adfb 100644 --- a/src/lib/components/chat/MessageInput/Commands/Models.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Models.svelte @@ -68,15 +68,11 @@ {#if filteredItems.length > 0}
-
-
-
@
-
- +
{#each filteredItems as model, modelIdx} diff --git a/src/lib/components/chat/MessageInput/Commands/Prompts.svelte b/src/lib/components/chat/MessageInput/Commands/Prompts.svelte index 2a4b46d5d..4d7d8bd2d 100644 --- a/src/lib/components/chat/MessageInput/Commands/Prompts.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Prompts.svelte @@ -132,17 +132,13 @@ {#if filteredPrompts.length > 0}
-
-
-
/
-
- +
-
+
{#each filteredPrompts as prompt, promptIdx} +
{$i18n.t('Chat Overview')}
+
diff --git a/src/lib/components/chat/Placeholder.svelte b/src/lib/components/chat/Placeholder.svelte index e8e84544c..baf433c42 100644 --- a/src/lib/components/chat/Placeholder.svelte +++ b/src/lib/components/chat/Placeholder.svelte @@ -89,7 +89,7 @@ {#key mounted} -
+
{#if $temporaryChatEnabled}
@@ -204,6 +204,9 @@ {stopResponse} {createMessagePair} placeholder={$i18n.t('How can I help you today?')} + on:upload={(e) => { + dispatch('upload', e.detail); + }} on:submit={(e) => { dispatch('submit', e.detail); }} diff --git a/src/lib/components/chat/Tags.svelte b/src/lib/components/chat/Tags.svelte index e6d01b3b5..7c5a0c0c1 100644 --- a/src/lib/components/chat/Tags.svelte +++ b/src/lib/components/chat/Tags.svelte @@ -25,51 +25,32 @@ let tags = []; const getTags = async () => { - return ( - await getTagsById(localStorage.token, chatId).catch(async (error) => { - return []; - }) - ).filter((tag) => tag.name !== 'pinned'); + return await getTagsById(localStorage.token, chatId).catch(async (error) => { + return []; + }); }; const addTag = async (tagName) => { const res = await addTagById(localStorage.token, chatId, tagName); tags = await getTags(); - await updateChatById(localStorage.token, chatId, { tags: tags }); _tags.set(await getAllChatTags(localStorage.token)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); }; const deleteTag = async (tagName) => { const res = await deleteTagById(localStorage.token, chatId, tagName); tags = await getTags(); - await updateChatById(localStorage.token, chatId, { tags: tags }); await _tags.set(await getAllChatTags(localStorage.token)); - if ($_tags.map((t) => t.name).includes(tagName)) { - if (tagName === 'pinned') { - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); - } else { - await chats.set(await getChatListByTagName(localStorage.token, tagName)); - } - - if ($chats.find((chat) => chat.id === chatId)) { - dispatch('close'); - } - } else { - // if the tag we deleted is no longer a valid tag, return to main chat list view - currentChatPage.set(1); - await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); - await scrollPaginationEnabled.set(true); - } + dispatch('delete', { + name: tagName + }); }; onMount(async () => { diff --git a/src/lib/components/common/SVGPanZoom.svelte b/src/lib/components/common/SVGPanZoom.svelte index 549fd2500..e576ffb06 100644 --- a/src/lib/components/common/SVGPanZoom.svelte +++ b/src/lib/components/common/SVGPanZoom.svelte @@ -1,11 +1,19 @@ -
+
{@html svg}
+ + {#if content} +
+ + + +
+ {/if}
diff --git a/src/lib/components/documents/AddDocModal.svelte b/src/lib/components/documents/AddDocModal.svelte deleted file mode 100644 index 8c4d478f7..000000000 --- a/src/lib/components/documents/AddDocModal.svelte +++ /dev/null @@ -1,166 +0,0 @@ - - - -
-
-
{$i18n.t('Add Docs')}
- -
-
-
-
{ - submitHandler(); - }} - > -
- - - -
- -
-
-
{$i18n.t('Tags')}
- - -
-
- -
- -
-
-
-
-
-
- - diff --git a/src/lib/components/documents/EditDocModal.svelte b/src/lib/components/documents/EditDocModal.svelte deleted file mode 100644 index 47577bd48..000000000 --- a/src/lib/components/documents/EditDocModal.svelte +++ /dev/null @@ -1,181 +0,0 @@ - - - -
-
-
{$i18n.t('Edit Doc')}
- -
-
-
-
{ - submitHandler(); - }} - > -
-
-
{$i18n.t('Name Tag')}
- -
-
- # -
- -
-
- -
-
{$i18n.t('Title')}
- -
- -
-
- -
-
{$i18n.t('Tags')}
- - -
-
- -
- -
-
-
-
-
-
- - diff --git a/src/lib/components/icons/ArrowLeft.svelte b/src/lib/components/icons/ArrowLeft.svelte new file mode 100644 index 000000000..166aee7f6 --- /dev/null +++ b/src/lib/components/icons/ArrowLeft.svelte @@ -0,0 +1,15 @@ + + + + + diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index ccc486d8b..954f99c38 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -10,6 +10,7 @@ showArchivedChats, showControls, showSidebar, + temporaryChatEnabled, user } from '$lib/stores'; @@ -23,6 +24,7 @@ import MenuLines from '../icons/MenuLines.svelte'; import AdjustmentsHorizontal from '../icons/AdjustmentsHorizontal.svelte'; import Map from '../icons/Map.svelte'; + import { stringify } from 'postcss'; const i18n = getContext('i18n'); @@ -74,8 +76,7 @@
- - {#if shareEnabled && chat && chat.id} + {#if shareEnabled && chat && (chat.id || $temporaryChatEnabled)} {}; const getChatAsText = async () => { - const _chat = chat.chat; - - const messages = createMessagesList(_chat.history, _chat.history.currentId); + const history = chat.chat.history; + const messages = createMessagesList(history, history.currentId); const chatText = messages.reduce((a, message, i, arr) => { return `${a}### ${message.role.toUpperCase()}\n${message.content}\n\n`; }, ''); @@ -52,12 +58,9 @@ }; const downloadPdf = async () => { - const _chat = chat.chat; - const messages = createMessagesList(_chat.history, _chat.history.currentId); - - console.log('download', chat); - - const blob = await downloadChatAsPDF(_chat.title, messages); + const history = chat.chat.history; + const messages = createMessagesList(history, history.currentId); + const blob = await downloadChatAsPDF(chat.chat.title, messages); // Create a URL for the blob const url = window.URL.createObjectURL(blob); @@ -65,7 +68,7 @@ // Create a link element to trigger the download const a = document.createElement('a'); a.href = url; - a.download = `chat-${_chat.title}.pdf`; + a.download = `chat-${chat.chat.title}.pdf`; // Append the link to the body and click it programmatically document.body.appendChild(a); @@ -79,6 +82,9 @@ }; const downloadJSONExport = async () => { + if (chat.id) { + chat = await getChatById(localStorage.token, chat.id); + } let blob = new Blob([JSON.stringify([chat])], { type: 'application/json' }); @@ -189,27 +195,29 @@
{$i18n.t('Copy')}
- { - shareHandler(); - }} - > - { + shareHandler(); + }} > - - -
{$i18n.t('Share')}
-
+ + + +
{$i18n.t('Share')}
+ + {/if} -
+ {#if !$temporaryChatEnabled} +
-
- -
+
+ +
+ {/if}
diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 0582eb574..ed7107c6f 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -19,7 +19,7 @@ showOverview, showControls } from '$lib/stores'; - import { onMount, getContext, tick } from 'svelte'; + import { onMount, getContext, tick, onDestroy } from 'svelte'; const i18n = getContext('i18n'); @@ -32,7 +32,10 @@ updateChatById, getAllChatTags, archiveChatById, - cloneChatById + cloneChatById, + getChatListBySearchText, + createNewChat, + getPinnedChatList } from '$lib/apis/chats'; import { WEBUI_BASE_URL } from '$lib/constants'; @@ -42,6 +45,9 @@ import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import Spinner from '../common/Spinner.svelte'; import Loader from '../common/Loader.svelte'; + import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte'; + import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte'; + import { select } from 'd3-selection'; const BREAKPOINT = 768; @@ -58,33 +64,11 @@ let selectedTagName = null; - let filteredChatList = []; - // Pagination variables let chatListLoading = false; let allChatsLoaded = false; - $: filteredChatList = $chats.filter((chat) => { - if (search === '') { - return true; - } else { - let title = chat.title.toLowerCase(); - const query = search.toLowerCase(); - - let contentMatches = false; - // Access the messages within chat.chat.messages - if (chat.chat && chat.chat.messages && Array.isArray(chat.chat.messages)) { - contentMatches = chat.chat.messages.some((message) => { - // Check if message.content exists and includes the search query - return message.content && message.content.toLowerCase().includes(query); - }); - } - - return title.includes(query) || contentMatches; - } - }); - - const enablePagination = async () => { + const initChatList = async () => { // Reset pagination variables currentChatPage.set(1); allChatsLoaded = false; @@ -98,7 +82,14 @@ chatListLoading = true; currentChatPage.set($currentChatPage + 1); - const newChatList = await getChatList(localStorage.token, $currentChatPage); + + let newChatList = []; + + if (search) { + newChatList = await getChatListBySearchText(localStorage.token, search, $currentChatPage); + } else { + newChatList = await getChatList(localStorage.token, $currentChatPage); + } // once the bottom of the list has been reached (no results) there is no need to continue querying allChatsLoaded = newChatList.length === 0; @@ -107,110 +98,26 @@ chatListLoading = false; }; - onMount(async () => { - mobile.subscribe((e) => { - if ($showSidebar && e) { - showSidebar.set(false); - } + let searchDebounceTimeout; - if (!$showSidebar && !e) { - showSidebar.set(true); - } - }); + const searchDebounceHandler = async () => { + console.log('search', search); + chats.set(null); + selectedTagName = null; - showSidebar.set(!$mobile ? localStorage.sidebar === 'true' : false); - showSidebar.subscribe((value) => { - localStorage.sidebar = value; - }); - - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); - await enablePagination(); - - let touchstart; - let touchend; - - function checkDirection() { - const screenWidth = window.innerWidth; - const swipeDistance = Math.abs(touchend.screenX - touchstart.screenX); - if (touchstart.clientX < 40 && swipeDistance >= screenWidth / 8) { - if (touchend.screenX < touchstart.screenX) { - showSidebar.set(false); - } - if (touchend.screenX > touchstart.screenX) { - showSidebar.set(true); - } - } + if (searchDebounceTimeout) { + clearTimeout(searchDebounceTimeout); } - const onTouchStart = (e) => { - touchstart = e.changedTouches[0]; - console.log(touchstart.clientX); - }; - - const onTouchEnd = (e) => { - touchend = e.changedTouches[0]; - checkDirection(); - }; - - const onKeyDown = (e) => { - if (e.key === 'Shift') { - shiftKey = true; - } - }; - - const onKeyUp = (e) => { - if (e.key === 'Shift') { - shiftKey = false; - } - }; - - const onFocus = () => {}; - - const onBlur = () => { - shiftKey = false; - selectedChatId = null; - }; - - window.addEventListener('keydown', onKeyDown); - window.addEventListener('keyup', onKeyUp); - - window.addEventListener('touchstart', onTouchStart); - window.addEventListener('touchend', onTouchEnd); - - window.addEventListener('focus', onFocus); - window.addEventListener('blur', onBlur); - - return () => { - window.removeEventListener('keydown', onKeyDown); - window.removeEventListener('keyup', onKeyUp); - - window.removeEventListener('touchstart', onTouchStart); - window.removeEventListener('touchend', onTouchEnd); - - window.removeEventListener('focus', onFocus); - window.removeEventListener('blur', onBlur); - }; - }); - - // Helper function to fetch and add chat content to each chat - const enrichChatsWithContent = async (chatList) => { - const enrichedChats = await Promise.all( - chatList.map(async (chat) => { - const chatDetails = await getChatById(localStorage.token, chat.id).catch((error) => null); // Handle error or non-existent chat gracefully - if (chatDetails) { - chat.chat = chatDetails.chat; // Assuming chatDetails.chat contains the chat content - } - return chat; - }) - ); - - await chats.set(enrichedChats); - }; - - const saveSettings = async (updated) => { - await settings.set({ ...$settings, ...updated }); - await updateUserSettings(localStorage.token, { ui: $settings }); - location.href = '/'; + if (search === '') { + await initChatList(); + return; + } else { + searchDebounceTimeout = setTimeout(async () => { + currentChatPage.set(1); + await chats.set(await getChatListBySearchText(localStorage.token, search)); + }, 1000); + } }; const deleteChatHandler = async (id) => { @@ -230,9 +137,175 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); } }; + + const inputFilesHandler = async (files) => { + console.log(files); + + for (const file of files) { + const reader = new FileReader(); + reader.onload = async (e) => { + const content = e.target.result; + + try { + const items = JSON.parse(content); + + for (const item of items) { + if (item.chat) { + await createNewChat(localStorage.token, item.chat); + } + } + } catch { + toast.error($i18n.t(`Invalid file format.`)); + } + + initChatList(); + }; + + reader.readAsText(file); + } + }; + + const tagEventHandler = async (type, tagName, chatId) => { + console.log(type, tagName, chatId); + if (type === 'delete') { + if (selectedTagName === tagName) { + if ($tags.map((t) => t.name).includes(tagName)) { + await chats.set(await getChatListByTagName(localStorage.token, tagName)); + } else { + selectedTagName = null; + await initChatList(); + } + } + } + }; + + let dragged = false; + + const onDragOver = (e) => { + e.preventDefault(); + dragged = true; + }; + + const onDragLeave = () => { + dragged = false; + }; + + const onDrop = async (e) => { + e.preventDefault(); + console.log(e); + + if (e.dataTransfer?.files) { + const inputFiles = Array.from(e.dataTransfer?.files); + if (inputFiles && inputFiles.length > 0) { + console.log(inputFiles); + inputFilesHandler(inputFiles); + } else { + toast.error($i18n.t(`File not found.`)); + } + } + + dragged = false; + }; + + let touchstart; + let touchend; + + function checkDirection() { + const screenWidth = window.innerWidth; + const swipeDistance = Math.abs(touchend.screenX - touchstart.screenX); + if (touchstart.clientX < 40 && swipeDistance >= screenWidth / 8) { + if (touchend.screenX < touchstart.screenX) { + showSidebar.set(false); + } + if (touchend.screenX > touchstart.screenX) { + showSidebar.set(true); + } + } + } + + const onTouchStart = (e) => { + touchstart = e.changedTouches[0]; + console.log(touchstart.clientX); + }; + + const onTouchEnd = (e) => { + touchend = e.changedTouches[0]; + checkDirection(); + }; + + const onKeyDown = (e) => { + if (e.key === 'Shift') { + shiftKey = true; + } + }; + + const onKeyUp = (e) => { + if (e.key === 'Shift') { + shiftKey = false; + } + }; + + const onFocus = () => {}; + + const onBlur = () => { + shiftKey = false; + selectedChatId = null; + }; + + onMount(async () => { + mobile.subscribe((e) => { + if ($showSidebar && e) { + showSidebar.set(false); + } + + if (!$showSidebar && !e) { + showSidebar.set(true); + } + }); + + showSidebar.set(!$mobile ? localStorage.sidebar === 'true' : false); + showSidebar.subscribe((value) => { + localStorage.sidebar = value; + }); + + await pinnedChats.set(await getPinnedChatList(localStorage.token)); + await initChatList(); + + window.addEventListener('keydown', onKeyDown); + window.addEventListener('keyup', onKeyUp); + + window.addEventListener('touchstart', onTouchStart); + window.addEventListener('touchend', onTouchEnd); + + window.addEventListener('focus', onFocus); + window.addEventListener('blur', onBlur); + + const dropZone = document.getElementById('sidebar'); + + dropZone?.addEventListener('dragover', onDragOver); + dropZone?.addEventListener('drop', onDrop); + dropZone?.addEventListener('dragleave', onDragLeave); + }); + + onDestroy(() => { + window.removeEventListener('keydown', onKeyDown); + window.removeEventListener('keyup', onKeyUp); + + window.removeEventListener('touchstart', onTouchStart); + window.removeEventListener('touchend', onTouchEnd); + + window.removeEventListener('focus', onFocus); + window.removeEventListener('blur', onBlur); + + const dropZone = document.getElementById('sidebar'); + + dropZone?.removeEventListener('dragover', onDragOver); + dropZone?.removeEventListener('drop', onDrop); + dropZone?.removeEventListener('dragleave', onDragLeave); + }); + {#if dragged} +
+
+ +
+
+ {/if}
{ - // TODO: migrate backend for more scalable search mechanism - scrollPaginationEnabled.set(false); - await chats.set(await getChatList(localStorage.token)); // when searching, load all chats - enrichChatsWithContent($chats); + on:input={() => { + searchDebounceHandler(); }} />
- {#if $tags.filter((t) => t.name !== 'pinned').length > 0} -
+ {#if $tags.length > 0} +
- {#each $tags.filter((t) => t.name !== 'pinned') as tag} + {#each $tags as tag}
diff --git a/src/lib/components/layout/Sidebar/ChatItem.svelte b/src/lib/components/layout/Sidebar/ChatItem.svelte index 4a633fddf..23a168ed3 100644 --- a/src/lib/components/layout/Sidebar/ChatItem.svelte +++ b/src/lib/components/layout/Sidebar/ChatItem.svelte @@ -12,6 +12,7 @@ deleteChatById, getChatList, getChatListByTagName, + getPinnedChatList, updateChatById } from '$lib/apis/chats'; import { @@ -55,7 +56,7 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); } }; @@ -70,7 +71,7 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); } }; @@ -79,7 +80,7 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); }; const focusEdit = async (node: HTMLInputElement) => { @@ -256,7 +257,10 @@ dispatch('unselect'); }} on:change={async () => { - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); + }} + on:tag={(e) => { + dispatch('tag', e.detail); }} > -
-
- - - -
- -{#if tags.length > 0} -
-
- doc?.selected === 'checked').length === - filteredDocs.length - ? 'checked' - : 'unchecked'} - indeterminate={filteredDocs.filter((doc) => doc?.selected === 'checked').length > 0 && - filteredDocs.filter((doc) => doc?.selected === 'checked').length !== filteredDocs.length} - on:change={(e) => { - if (e.detail === 'checked') { - filteredDocs = filteredDocs.map((doc) => ({ ...doc, selected: 'checked' })); - } else if (e.detail === 'unchecked') { - filteredDocs = filteredDocs.map((doc) => ({ ...doc, selected: 'unchecked' })); - } - }} - /> -
- - {#if filteredDocs.filter((doc) => doc?.selected === 'checked').length === 0} - - - {#each tags as tag} - - {/each} - {:else} -
-
- {filteredDocs.filter((doc) => doc?.selected === 'checked').length} Selected -
- -
- - - -
-
- {/if} -
-{/if} - -
- {#each filteredDocs as doc} - - - - - -
- - {/each} -
- -
- ⓘ {$i18n.t("Use '#' in the prompt input to load and select your documents.")} -
- -
-
- { - console.log(importFiles); - - const reader = new FileReader(); - reader.onload = async (event) => { - const savedDocs = JSON.parse(event.target.result); - console.log(savedDocs); - - for (const doc of savedDocs) { - await createNewDoc( - localStorage.token, - doc.collection_name, - doc.filename, - doc.name, - doc.title, - doc.content - ).catch((error) => { - toast.error(error); - return null; - }); - } - - await documents.set(await getDocs(localStorage.token)); - }; - - reader.readAsText(importFiles[0]); - }} - /> - - - - -
-
diff --git a/src/lib/components/workspace/Knowledge.svelte b/src/lib/components/workspace/Knowledge.svelte index 1706ba0aa..6ed4864e8 100644 --- a/src/lib/components/workspace/Knowledge.svelte +++ b/src/lib/components/workspace/Knowledge.svelte @@ -181,7 +181,8 @@ {/if}
- Updated {dayjs(item.updated_at * 1000).fromNow()} + {$i18n.t('Updated')} + {dayjs(item.updated_at * 1000).fromNow()}
diff --git a/src/lib/components/workspace/Knowledge/Collection.svelte b/src/lib/components/workspace/Knowledge/Collection.svelte index 606d4ff58..e59144ac7 100644 --- a/src/lib/components/workspace/Knowledge/Collection.svelte +++ b/src/lib/components/workspace/Knowledge/Collection.svelte @@ -1,6 +1,7 @@