mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'upstream-dev' into dev
This commit is contained in:
@@ -450,7 +450,7 @@ def transcribe(file_path):
|
||||
except Exception:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
raise error_detail
|
||||
raise Exception(error_detail)
|
||||
|
||||
|
||||
@app.post("/transcriptions")
|
||||
|
||||
@@ -125,22 +125,34 @@ async def comfyui_generate_image(
|
||||
workflow[node_id]["inputs"][node.key] = model
|
||||
elif node.type == "prompt":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["text"] = payload.prompt
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "text"
|
||||
] = payload.prompt
|
||||
elif node.type == "negative_prompt":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["text"] = payload.negative_prompt
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "text"
|
||||
] = payload.negative_prompt
|
||||
elif node.type == "width":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["width"] = payload.width
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "width"
|
||||
] = payload.width
|
||||
elif node.type == "height":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["height"] = payload.height
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "height"
|
||||
] = payload.height
|
||||
elif node.type == "n":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["batch_size"] = payload.n
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "batch_size"
|
||||
] = payload.n
|
||||
elif node.type == "steps":
|
||||
for node_id in node.node_ids:
|
||||
workflow[node_id]["inputs"]["steps"] = payload.steps
|
||||
workflow[node_id]["inputs"][
|
||||
node.key if node.key else "steps"
|
||||
] = payload.steps
|
||||
elif node.type == "seed":
|
||||
seed = (
|
||||
payload.seed
|
||||
|
||||
@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
|
||||
|
||||
class GenerateEmbedForm(BaseModel):
|
||||
model: str
|
||||
input: str
|
||||
truncate: Optional[bool]
|
||||
input: list[str]|str
|
||||
truncate: Optional[bool] = None
|
||||
options: Optional[dict] = None
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
|
||||
@@ -560,48 +560,7 @@ async def generate_embeddings(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/embed",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
try:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except Exception:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r else 500,
|
||||
detail=error_detail,
|
||||
)
|
||||
return generate_ollama_batch_embeddings(form_data, url_idx)
|
||||
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
@@ -611,48 +570,7 @@ async def generate_embeddings(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/embeddings",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
try:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except Exception:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r else 500,
|
||||
detail=error_detail,
|
||||
)
|
||||
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
||||
|
||||
|
||||
def generate_ollama_embeddings(
|
||||
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
|
||||
log.info(f"generate_ollama_embeddings {data}")
|
||||
|
||||
if "embedding" in data:
|
||||
return data["embedding"]
|
||||
return data
|
||||
else:
|
||||
raise Exception("Something went wrong :/")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except Exception:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r else 500,
|
||||
detail=error_detail,
|
||||
)
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
form_data: GenerateEmbedForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/embed",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
try:
|
||||
r.raise_for_status()
|
||||
|
||||
data = r.json()
|
||||
|
||||
log.info(f"generate_ollama_batch_embeddings {data}")
|
||||
|
||||
if "embeddings" in data:
|
||||
return data
|
||||
else:
|
||||
raise Exception("Something went wrong :/")
|
||||
except Exception as e:
|
||||
@@ -788,8 +763,7 @@ async def generate_chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
payload = {**form_data.model_dump(exclude_none=True)}
|
||||
log.debug(f"{payload = }")
|
||||
|
||||
log.debug(f"generate_chat_completion() - 1.payload = {payload}")
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
|
||||
@@ -824,7 +798,7 @@ async def generate_chat_completion(
|
||||
|
||||
url = get_ollama_url(url_idx, payload["model"])
|
||||
log.info(f"url: {url}")
|
||||
log.debug(payload)
|
||||
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/api/chat",
|
||||
|
||||
@@ -63,7 +63,7 @@ from open_webui.config import (
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
RAG_EMBEDDING_BATCH_SIZE,
|
||||
RAG_FILE_MAX_COUNT,
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -267,7 +267,7 @@ async def get_status():
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
||||
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
}
|
||||
|
||||
|
||||
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||
"status": True,
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
batch_size: Optional[int] = None
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
|
||||
|
||||
@app.post("/embedding/update")
|
||||
@@ -320,11 +320,7 @@ async def update_embedding_config(
|
||||
if form_data.openai_config is not None:
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
|
||||
form_data.openai_config.batch_size
|
||||
if form_data.openai_config.batch_size
|
||||
else 1
|
||||
)
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||
|
||||
@@ -334,17 +330,17 @@ async def update_embedding_config(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -645,7 +641,7 @@ def save_docs_to_vector_db(
|
||||
filter={"hash": metadata["hash"]},
|
||||
)
|
||||
|
||||
if result:
|
||||
if result is not None:
|
||||
existing_doc_ids = result.ids[0]
|
||||
if existing_doc_ids:
|
||||
log.info(f"Document with hash {metadata['hash']} already exists")
|
||||
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
@@ -767,7 +763,7 @@ def process_file(
|
||||
collection_name=f"file-{file.id}", filter={"file_id": file.id}
|
||||
)
|
||||
|
||||
if len(result.ids[0]) > 0:
|
||||
if result is not None and len(result.ids[0]) > 0:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=result.documents[0][idx],
|
||||
|
||||
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbeddingsForm,
|
||||
generate_ollama_embeddings,
|
||||
GenerateEmbedForm,
|
||||
generate_ollama_batch_embeddings,
|
||||
)
|
||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
@@ -193,7 +193,8 @@ def query_collection(
|
||||
k=k,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
results.append(result.model_dump())
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
@@ -265,39 +266,27 @@ def get_embedding_function(
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
batch_size,
|
||||
embedding_batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
if embedding_engine == "ollama":
|
||||
func = lambda query: generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
func = lambda query: generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
func = lambda query: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key if embedding_engine == "openai" else "",
|
||||
url=openai_url if embedding_engine == "openai" else "",
|
||||
)
|
||||
|
||||
def generate_multiple(query, f):
|
||||
def generate_multiple(query, func):
|
||||
if isinstance(query, list):
|
||||
if embedding_engine == "openai":
|
||||
embeddings = []
|
||||
for i in range(0, len(query), batch_size):
|
||||
embeddings.extend(f(query[i : i + batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return [f(q) for q in query]
|
||||
embeddings = []
|
||||
for i in range(0, len(query), embedding_batch_size):
|
||||
embeddings.extend(func(query[i : i + embedding_batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return f(query)
|
||||
return func(query)
|
||||
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
@@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str,
|
||||
text: Union[str, list[str]],
|
||||
key: str,
|
||||
url: str = "https://api.openai.com/v1",
|
||||
):
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
) -> Optional[list[list[float]]]:
|
||||
@@ -482,6 +457,33 @@ def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": text})
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": [text]})
|
||||
)
|
||||
return (
|
||||
embeddings["embeddings"][0]
|
||||
if isinstance(text, str)
|
||||
else embeddings["embeddings"]
|
||||
)
|
||||
elif engine == "openai":
|
||||
key = kwargs.get("key", "")
|
||||
url = kwargs.get("url", "https://api.openai.com/v1")
|
||||
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
import operator
|
||||
from typing import Optional, Sequence
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
|
||||
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
else:
|
||||
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
|
||||
176
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
176
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(
|
||||
collection_name=collection_name
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
|
||||
def _create_points(self, items: list[VectorItem]):
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"]
|
||||
},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
query_response = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query=vectors[0],
|
||||
limit=limit,
|
||||
)
|
||||
get_result = self._result_to_get_result(query_response.points)
|
||||
return SearchResult(
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]]
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
try:
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
|
||||
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
field_conditions = []
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
),
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
),
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(
|
||||
must=field_conditions
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.get_collections().collections
|
||||
for collection_name in collection_names:
|
||||
if collection_name.name.startswith(self.collection_prefix):
|
||||
self.client.delete_collection(collection_name=collection_name.name)
|
||||
@@ -4,8 +4,13 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
@@ -18,13 +23,16 @@ class Chat(Base):
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
chat = Column(Text) # Save Chat JSON as Text
|
||||
chat = Column(JSON)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
share_id = Column(Text, unique=True, nullable=True)
|
||||
archived = Column(Boolean, default=False)
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default="{}")
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
@@ -33,13 +41,16 @@ class ChatModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
title: str
|
||||
chat: str
|
||||
chat: dict
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
share_id: Optional[str] = None
|
||||
archived: bool = False
|
||||
pinned: Optional[bool] = False
|
||||
|
||||
meta: dict = {}
|
||||
|
||||
|
||||
####################
|
||||
@@ -64,6 +75,8 @@ class ChatResponse(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
share_id: Optional[str] = None # id of the chat to be shared
|
||||
archived: bool
|
||||
pinned: Optional[bool] = False
|
||||
meta: dict = {}
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
@@ -86,7 +99,7 @@ class ChatTable:
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": json.dumps(form_data.chat),
|
||||
"chat": form_data.chat,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
@@ -101,14 +114,14 @@ class ChatTable:
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat_obj = db.get(Chat, id)
|
||||
chat_obj.chat = json.dumps(chat)
|
||||
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
|
||||
chat_obj.updated_at = int(time.time())
|
||||
chat_item = db.get(Chat, id)
|
||||
chat_item.chat = chat
|
||||
chat_item.title = chat["title"] if "title" in chat else "New Chat"
|
||||
chat_item.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat_obj)
|
||||
db.refresh(chat_item)
|
||||
|
||||
return ChatModel.model_validate(chat_obj)
|
||||
return ChatModel.model_validate(chat_item)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -182,11 +195,24 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.pinned = not chat.pinned
|
||||
chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.archived = not chat.archived
|
||||
chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
@@ -249,10 +275,10 @@ class ChatTable:
|
||||
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
|
||||
@@ -328,6 +354,15 @@ class ChatTable:
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, pinned=True)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
@@ -337,6 +372,207 @@ class ChatTable:
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_user_id_and_search_text(
|
||||
self,
|
||||
user_id: str,
|
||||
search_text: str,
|
||||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
) -> list[ChatModel]:
|
||||
"""
|
||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||
"""
|
||||
search_text = search_text.lower().strip()
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
|
||||
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if not include_archived:
|
||||
query = query.filter(Chat.archived == False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
# Check if the database dialect is either 'sqlite' or 'postgresql'
|
||||
dialect_name = db.bind.dialect.name
|
||||
if dialect_name == "sqlite":
|
||||
# SQLite case: using JSON1 extension for JSON searching
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.chat, '$.messages') AS message
|
||||
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
# PostgreSQL relies on proper JSON query for search
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements(Chat.chat->'messages') AS message
|
||||
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
|
||||
|
||||
def get_chat_list_by_user_id_and_tag_name(
|
||||
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
print(db.bind.dialect.name)
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 querying for tags within the meta JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
all_chats = query.all()
|
||||
print("all_chats", all_chats)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
) -> Optional[ChatModel]:
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
|
||||
if tag is None:
|
||||
tag = Tags.insert_new_tag(tag_name, user_id)
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
|
||||
tag_id = tag.id
|
||||
if tag_id not in chat.meta.get("tags", []):
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": chat.meta.get("tags", []) + [tag_id],
|
||||
}
|
||||
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
|
||||
with get_db() as db: # Assuming `get_db()` returns a session object
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Get the count of matching records
|
||||
count = query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
print(f"Count of chats for tag '{tag_name}':", count)
|
||||
|
||||
return count
|
||||
|
||||
def delete_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
tags = [tag for tag in tags if tag != tag_id]
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": tags,
|
||||
}
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": [],
|
||||
}
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -50,6 +50,14 @@ class FileModel(BaseModel):
|
||||
####################
|
||||
|
||||
|
||||
class FileMeta(BaseModel):
|
||||
name: Optional[str] = None
|
||||
content_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FileModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
@@ -57,12 +65,19 @@ class FileModelResponse(BaseModel):
|
||||
|
||||
filename: str
|
||||
data: Optional[dict] = None
|
||||
meta: dict
|
||||
meta: FileMeta
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FileMetadataResponse(BaseModel):
|
||||
id: str
|
||||
meta: dict
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FileForm(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
@@ -104,6 +119,19 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileMetadataResponse(
|
||||
id=file.id,
|
||||
meta=file.meta,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_files(self) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
@@ -118,6 +146,21 @@ class FilesTable:
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FileMetadataResponse(
|
||||
id=file.id,
|
||||
meta=file.meta,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
)
|
||||
for file in db.query(File)
|
||||
.filter(File.id.in_(ids))
|
||||
.order_by(File.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
|
||||
@@ -6,6 +6,10 @@ import uuid
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.apps.webui.models.files import FileMetadataResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
@@ -64,6 +68,8 @@ class KnowledgeResponse(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
files: Optional[list[FileMetadataResponse | dict]] = None
|
||||
|
||||
|
||||
class KnowledgeForm(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -4,53 +4,32 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, JSON
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Tag DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
data = Column(Text, nullable=True)
|
||||
|
||||
|
||||
class ChatIdTag(Base):
|
||||
__tablename__ = "chatidtag"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
tag_name = Column(String)
|
||||
chat_id = Column(String)
|
||||
user_id = Column(String)
|
||||
timestamp = Column(BigInteger)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
|
||||
class TagModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
user_id: str
|
||||
data: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChatIdTagModel(BaseModel):
|
||||
id: str
|
||||
tag_name: str
|
||||
chat_id: str
|
||||
user_id: str
|
||||
timestamp: int
|
||||
|
||||
meta: Optional[dict] = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel):
|
||||
####################
|
||||
|
||||
|
||||
class ChatIdTagForm(BaseModel):
|
||||
tag_name: str
|
||||
class TagChatIdForm(BaseModel):
|
||||
name: str
|
||||
chat_id: str
|
||||
|
||||
|
||||
class TagChatIdsResponse(BaseModel):
|
||||
chat_ids: list[str]
|
||||
|
||||
|
||||
class ChatTagsResponse(BaseModel):
|
||||
tags: list[str]
|
||||
|
||||
|
||||
class TagTable:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
id = name.replace(" ", "_").lower()
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
result = Tag(**tag.model_dump())
|
||||
@@ -93,170 +64,38 @@ class TagTable:
|
||||
self, name: str, user_id: str
|
||||
) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(" ", "_").lower()
|
||||
with get_db() as db:
|
||||
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
|
||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
return TagModel.model_validate(tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def add_tag_to_chat(
|
||||
self, user_id: str, form_data: ChatIdTagForm
|
||||
) -> Optional[ChatIdTagModel]:
|
||||
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
|
||||
if tag is None:
|
||||
tag = self.insert_new_tag(form_data.tag_name, user_id)
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
chatIdTag = ChatIdTagModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"chat_id": form_data.chat_id,
|
||||
"tag_name": tag.name,
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = ChatIdTag(**chatIdTag.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return ChatIdTagModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
tag_names = [
|
||||
chat_id_tag.tag_name
|
||||
for chat_id_tag in (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(ChatIdTag.timestamp.desc())
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
db.query(Tag)
|
||||
.filter_by(user_id=user_id)
|
||||
.filter(Tag.name.in_(tag_names))
|
||||
.all()
|
||||
)
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_chat_id_and_user_id(
|
||||
self, chat_id: str, user_id: str
|
||||
) -> list[TagModel]:
|
||||
def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
tag_names = [
|
||||
chat_id_tag.tag_name
|
||||
for chat_id_tag in (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(user_id=user_id, chat_id=chat_id)
|
||||
.order_by(ChatIdTag.timestamp.desc())
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
db.query(Tag)
|
||||
.filter_by(user_id=user_id)
|
||||
.filter(Tag.name.in_(tag_names))
|
||||
.all()
|
||||
)
|
||||
for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
|
||||
]
|
||||
|
||||
def get_chat_ids_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str
|
||||
) -> list[ChatIdTagModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
ChatIdTagModel.model_validate(chat_id_tag)
|
||||
for chat_id_tag in (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(user_id=user_id, tag_name=tag_name)
|
||||
.order_by(ChatIdTag.timestamp.desc())
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
def count_chat_ids_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str
|
||||
) -> int:
|
||||
with get_db() as db:
|
||||
return (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(tag_name=tag_name, user_id=user_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
res = (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(tag_name=tag_name, user_id=user_id)
|
||||
.delete()
|
||||
)
|
||||
id = name.replace(" ", "_").lower()
|
||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
log.debug(f"res: {res}")
|
||||
db.commit()
|
||||
|
||||
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
|
||||
tag_name, user_id
|
||||
)
|
||||
if tag_count == 0:
|
||||
# Remove tag item from Tag col as well
|
||||
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
self, tag_name: str, chat_id: str, user_id: str
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
res = (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
|
||||
.delete()
|
||||
)
|
||||
log.debug(f"res: {res}")
|
||||
db.commit()
|
||||
|
||||
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
|
||||
tag_name, user_id
|
||||
)
|
||||
if tag_count == 0:
|
||||
# Remove tag item from Tag col as well
|
||||
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
|
||||
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
|
||||
|
||||
for tag in tags:
|
||||
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
tag.tag_name, chat_id, user_id
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
Tags = TagTable()
|
||||
|
||||
@@ -18,6 +18,8 @@ from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import Response
|
||||
@@ -27,6 +29,7 @@ from open_webui.utils.utils import (
|
||||
create_api_key,
|
||||
create_token,
|
||||
get_admin_user,
|
||||
get_verified_user,
|
||||
get_current_user,
|
||||
get_password_hash,
|
||||
)
|
||||
@@ -53,6 +56,8 @@ async def get_session_user(
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -71,7 +76,7 @@ async def get_session_user(
|
||||
|
||||
@router.post("/update/profile", response_model=UserResponse)
|
||||
async def update_profile(
|
||||
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
|
||||
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
||||
):
|
||||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
@@ -166,6 +171,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -236,6 +243,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
|
||||
@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
|
||||
Chats,
|
||||
ChatTitleIdResponse,
|
||||
)
|
||||
from open_webui.apps.webui.models.tags import (
|
||||
ChatIdTagForm,
|
||||
ChatIdTagModel,
|
||||
TagModel,
|
||||
Tags,
|
||||
)
|
||||
from open_webui.apps.webui.models.tags import TagModel, Tags
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
@@ -95,7 +91,7 @@ async def get_user_chat_list_by_user_id(
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
@@ -108,10 +104,46 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/search", response_model=list[ChatTitleIdResponse])
|
||||
async def search_user_chats(
|
||||
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
|
||||
):
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
return [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id_and_search_text(
|
||||
user.id, text, skip=skip, limit=limit
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/pinned", response_model=list[ChatResponse])
|
||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/all", response_model=list[ChatResponse])
|
||||
async def get_user_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
@@ -124,11 +156,28 @@ async def get_user_chats(user=Depends(get_verified_user)):
|
||||
@router.get("/all/archived", response_model=list[ChatResponse])
|
||||
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/all/tags", response_model=list[TagModel])
|
||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetAllChatsInDB
|
||||
############################
|
||||
@@ -141,10 +190,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_chats()
|
||||
]
|
||||
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
|
||||
|
||||
|
||||
############################
|
||||
@@ -187,7 +233,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id(share_id)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -199,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
class TagNameForm(BaseModel):
|
||||
class TagForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TagFilterForm(TagForm):
|
||||
skip: Optional[int] = 0
|
||||
limit: Optional[int] = 50
|
||||
|
||||
|
||||
@router.post("/tags", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_tag_name(
|
||||
form_data: TagNameForm, user=Depends(get_verified_user)
|
||||
form_data: TagFilterForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat_ids = [
|
||||
chat_id_tag.chat_id
|
||||
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
|
||||
form_data.name, user.id
|
||||
)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
|
||||
|
||||
chats = Chats.get_chat_list_by_user_id_and_tag_name(
|
||||
user.id, form_data.name, form_data.skip, form_data.limit
|
||||
)
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
|
||||
return chats
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/all", response_model=list[TagModel])
|
||||
async def get_all_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatById
|
||||
############################
|
||||
@@ -251,7 +278,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -269,10 +297,9 @@ async def update_chat_by_id(
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
updated_chat = {**json.loads(chat.chat), **form_data.chat}
|
||||
|
||||
updated_chat = {**chat.chat, **form_data.chat}
|
||||
chat = Chats.update_chat_by_id(id, updated_chat)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -303,25 +330,57 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedStatusById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/pinned", response_model=Optional[bool])
|
||||
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
return chat.pinned
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# PinChatById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
|
||||
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_pinned_by_id(id)
|
||||
return chat
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CloneChat
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat_body = json.loads(chat.chat)
|
||||
updated_chat = {
|
||||
**chat_body,
|
||||
**chat.chat,
|
||||
"originalChatId": chat.id,
|
||||
"branchPointMessageId": chat_body["history"]["currentId"],
|
||||
"branchPointMessageId": chat.chat["history"]["currentId"],
|
||||
"title": f"Clone of {chat.title}",
|
||||
}
|
||||
|
||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@@ -333,12 +392,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@@ -356,9 +415,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
if chat:
|
||||
if chat.share_id:
|
||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
||||
return ChatResponse(
|
||||
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
||||
)
|
||||
return ChatResponse(**shared_chat.model_dump())
|
||||
|
||||
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
|
||||
if not shared_chat:
|
||||
@@ -366,10 +423,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
return ChatResponse(**shared_chat.model_dump())
|
||||
|
||||
return ChatResponse(
|
||||
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -407,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
@router.get("/{id}/tags", response_model=list[TagModel])
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if tags != None:
|
||||
return tags
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -422,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
|
||||
async def add_chat_tag_by_id(
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
@router.post("/{id}/tags", response_model=list[TagModel])
|
||||
async def add_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = form_data.name.replace(" ", "_").lower()
|
||||
|
||||
if form_data.tag_name not in tags:
|
||||
tag = Tags.add_tag_to_chat(user.id, form_data)
|
||||
|
||||
if tag:
|
||||
return tag
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
print(tags, tag_id)
|
||||
if tag_id not in tags:
|
||||
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
id, user.id, form_data.name
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@@ -449,16 +506,20 @@ async def add_chat_tag_by_id(
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/tags", response_model=Optional[bool])
|
||||
async def delete_chat_tag_by_id(
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
@router.delete("/{id}/tags", response_model=list[TagModel])
|
||||
async def delete_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
form_data.tag_name, id, user.id
|
||||
)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
|
||||
|
||||
if result:
|
||||
return result
|
||||
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -472,10 +533,17 @@ async def delete_chat_tag_by_id(
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
|
||||
if result:
|
||||
return result
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
||||
@@ -213,7 +213,7 @@ async def update_file_data_content_by_id(
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/content", response_model=Optional[FileModel])
|
||||
@router.get("/{id}/content")
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
@@ -223,7 +223,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
|
||||
}
|
||||
return FileResponse(file_path, headers=headers)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -236,7 +239,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
|
||||
@router.get("/{id}/content/{file_name}")
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
@@ -248,7 +251,10 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
|
||||
}
|
||||
return FileResponse(file_path, headers=headers)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -48,7 +48,12 @@ async def get_knowledge_items(
|
||||
)
|
||||
else:
|
||||
return [
|
||||
KnowledgeResponse(**knowledge.model_dump())
|
||||
KnowledgeResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Files.get_file_metadatas_by_ids(
|
||||
knowledge.data.get("file_ids", []) if knowledge.data else []
|
||||
),
|
||||
)
|
||||
for knowledge in Knowledges.get_knowledge_items()
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user