Merge pull request #5861 from open-webui/projects

feat: knowledge/projects
This commit is contained in:
Timothy Jaeryang Baek
2024-10-04 10:00:47 +02:00
committed by GitHub
48 changed files with 2875 additions and 557 deletions

View File

@@ -1,3 +1,5 @@
# TODO: Merge this with the webui_app and make it a single app
import json
import logging
import mimetypes
@@ -634,9 +636,23 @@ def save_docs_to_vector_db(
metadata: Optional[dict] = None,
overwrite: bool = False,
split: bool = True,
add: bool = False,
) -> bool:
log.info(f"save_docs_to_vector_db {docs} {collection_name}")
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
result = VECTOR_DB_CLIENT.query(
collection_name=collection_name,
filter={"hash": metadata["hash"]},
)
if result:
existing_doc_ids = result.ids[0]
if existing_doc_ids:
log.info(f"Document with hash {metadata['hash']} already exists")
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
@@ -659,42 +675,46 @@ def save_docs_to_vector_db(
metadata[key] = str(value)
try:
if overwrite:
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
log.info(f"deleting existing collection {collection_name}")
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
log.info(f"collection {collection_name} already exists")
return True
else:
embedding_function = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embeddings = embedding_function(
list(map(lambda x: x.replace("\n", " "), texts))
)
if overwrite:
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
log.info(f"deleting existing collection {collection_name}")
VECTOR_DB_CLIENT.insert(
collection_name=collection_name,
items=[
{
"id": str(uuid.uuid4()),
"text": text,
"vector": embeddings[idx],
"metadata": metadatas[idx],
}
for idx, text in enumerate(texts)
],
)
if add is False:
return True
return True
log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embeddings = embedding_function(
list(map(lambda x: x.replace("\n", " "), texts))
)
items = [
{
"id": str(uuid.uuid4()),
"text": text,
"vector": embeddings[idx],
"metadata": metadatas[idx],
}
for idx, text in enumerate(texts)
]
VECTOR_DB_CLIENT.insert(
collection_name=collection_name,
items=items,
)
return True
except Exception as e:
log.exception(e)
return False
@@ -702,6 +722,7 @@ def save_docs_to_vector_db(
class ProcessFileForm(BaseModel):
# Input form read by process_file (as form_data) to decide which file to
# process, what text to index, and which collection to store it in.
# ID of the stored file; process_file loads it with Files.get_file_by_id.
file_id: str
# Optional text to index; when set, process_file uses it as the document
# content in place of the file's own content.
content: Optional[str] = None
# Optional target collection; when None, process_file chooses a collection
# name itself.
collection_name: Optional[str] = None
@@ -712,42 +733,91 @@ def process_file(
):
try:
file = Files.get_file_by_id(form_data.file_id)
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
collection_name = form_data.collection_name
if collection_name is None:
with open(file_path, "rb") as f:
collection_name = calculate_sha256(f)[:63]
collection_name = f"file-{file.id}"
loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
)
docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
text_content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {text_content}")
Files.update_files_metadata_by_id(
form_data.file_id,
{
"content": {
"text": text_content,
}
},
if form_data.content:
docs = [
Document(
page_content=form_data.content,
metadata={
"name": file.meta.get("name", file.filename),
"created_by": file.user_id,
**file.meta,
},
)
]
text_content = form_data.content
elif file.data.get("content", None):
docs = [
Document(
page_content=file.data.get("content", ""),
metadata={
"name": file.meta.get("name", file.filename),
"created_by": file.user_id,
**file.meta,
},
)
]
text_content = file.data.get("content", "")
else:
file_path = file.meta.get("path", None)
if file_path:
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
)
else:
docs = [
Document(
page_content=file.data.get("content", ""),
metadata={
"name": file.filename,
"created_by": file.user_id,
**file.meta,
},
)
]
text_content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {text_content}")
Files.update_file_data_by_id(
file.id,
{"content": text_content},
)
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
try:
result = save_docs_to_vector_db(
docs,
collection_name,
{
"file_id": form_data.file_id,
docs=docs,
collection_name=collection_name,
metadata={
"file_id": file.id,
"name": file.meta.get("name", file.filename),
"hash": hash,
},
add=(True if form_data.collection_name else False),
)
if result:
Files.update_file_metadata_by_id(
file.id,
{
"collection_name": collection_name,
},
)
return {
"status": True,
"collection_name": collection_name,
@@ -755,10 +825,7 @@ def process_file(
"content": text_content,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
raise e
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
@@ -769,7 +836,7 @@ def process_file(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=str(e),
)
@@ -1183,6 +1250,30 @@ def query_collection_handler(
####################################
class DeleteForm(BaseModel):
# Input form for the /delete endpoint (delete_entries_from_collection).
# Name of the vector collection to delete entries from.
collection_name: str
# ID of the file whose stored hash selects the entries to delete.
file_id: str
@app.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
    """Delete a file's entries from a vector collection, matched by content hash.

    Looks up the file named by ``form_data.file_id`` and asks the vector DB
    client to delete every entry in ``form_data.collection_name`` whose
    metadata ``hash`` equals the file's stored hash.

    Args:
        form_data: Collection name and file ID identifying what to delete.
        user: Admin user resolved by the ``get_admin_user`` dependency.

    Returns:
        ``{"status": True}`` once the delete call has been issued;
        ``{"status": False}`` when the collection does not exist, the file or
        its hash is missing, or any step raises.
    """
    try:
        if not VECTOR_DB_CLIENT.has_collection(
            collection_name=form_data.collection_name
        ):
            return {"status": False}

        file = Files.get_file_by_id(form_data.file_id)
        # Guard in case the lookup yields nothing or the file has no hash yet;
        # without a hash there is nothing to match entries against.
        if file is None or not file.hash:
            return {"status": False}

        # The vector clients' delete() accepts `ids` or `filter`. The previous
        # `metadata=` keyword raised TypeError, which the handler below
        # swallowed, so no entry was ever deleted.
        VECTOR_DB_CLIENT.delete(
            collection_name=form_data.collection_name,
            filter={"hash": file.hash},
        )
        return {"status": True}
    except Exception as e:
        # Best-effort endpoint: report failure to the caller instead of a 500.
        log.exception(e)
        return {"status": False}
@app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
VECTOR_DB_CLIENT.reset()

View File

@@ -319,17 +319,25 @@ def get_rag_context(
for file in files:
if file.get("context") == "full":
context = {
"documents": [[file.get("file").get("content")]],
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
else:
context = None
collection_names = (
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]] if file["collection_name"] else []
)
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
collection_names = file.get("collection_names", [])
else:
collection_names.append(file["id"])
elif file.get("collection_name"):
collection_names.append(file["collection_name"])
elif file.get("id"):
if file.get("legacy"):
collection_names.append(f"{file['id']}")
else:
collection_names.append(f"file-{file['id']}")
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:

View File

@@ -49,22 +49,52 @@ class ChromaClient:
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
try:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
except Exception as e:
return None
def query(
    self, collection_name: str, filter: dict, limit: int = 2
) -> Optional[GetResult]:
    """Fetch up to ``limit`` items of a collection whose metadata matches ``filter``.

    Args:
        collection_name: Name of the Chroma collection to read.
        filter: Metadata conditions, passed to ``collection.get`` as ``where``.
        limit: Maximum number of items to fetch.

    Returns:
        A ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` each hold
        a single inner list with the matches, or ``None`` when the collection
        object is falsy or any step raises.
    """
    import logging  # local import: this block cannot rely on a module-level logger

    try:
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None

        result = collection.get(where=filter, limit=limit)
        # Wrap each field in an outer list so the shape matches the nested
        # lists used by the other result types.
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )
    except Exception as e:
        # Best-effort lookup: failures are reported as "no result" rather than
        # raised. Logged instead of printed so stdout stays clean.
        logging.getLogger(__name__).debug(
            "query on collection %s failed: %s", collection_name, e
        )
        return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
@@ -111,11 +141,19 @@ class ChromaClient:
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
def delete(self, collection_name: str, ids: list[str]):
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete items from the collection: by `ids` when given, otherwise by the
# metadata `filter`, which is handed to the collection as `where`.
collection = self.client.get_collection(name=collection_name)
if collection:
collection.delete(ids=ids)
# `ids` takes precedence; `filter` is only consulted when `ids` is empty or None.
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
def reset(self):
# Resets the database. This will delete all collections and item entries.

View File

@@ -135,6 +135,25 @@ class MilvusClient:
return self._result_to_search_result(result)
def query(
    self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[GetResult]:
    """Return up to ``limit`` items whose metadata matches every pair in ``filter``.

    Args:
        collection_name: Collection name without the client's prefix; the
            prefix is prepended here.
        filter: Mapping of metadata key to required value. All pairs must
            match (joined with ``&&``).
        limit: Maximum number of items to fetch.

    Returns:
        The raw query result converted by ``_result_to_get_result``.
    """
    import json  # local import: used only to render values as quoted literals

    # Each condition is an equality test on one key of the JSON `metadata`
    # field. The key is quoted, and json.dumps renders the value as a
    # double-quoted/escaped literal. The previous JSON_CONTAINS form left the
    # key unquoted and embedded a Python list repr inside single quotes, which
    # does not form a valid expression.
    filter_string = " && ".join(
        f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()
    )

    result = self.client.query(
        collection_name=f"{self.collection_prefix}_{collection_name}",
        filter=filter_string,
        limit=limit,
    )
    return self._result_to_get_result([result])
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
result = self.client.query(
@@ -187,13 +206,32 @@ class MilvusClient:
],
)
def delete(self, collection_name: str, ids: list[str]):
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete items from the prefixed collection: by `ids` when given, otherwise
# by a metadata `filter` converted to a filter string.
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
# `ids` takes precedence; `filter` is only consulted when `ids` is empty or None.
if ids:
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
elif filter:
# Convert the filter dictionary to a string using JSON_CONTAINS.
# All key/value pairs must match (joined with "&&").
# NOTE(review): the key is interpolated without quotes and a str value is
# rendered as a Python list repr inside single quotes — confirm this forms
# a valid Milvus filter expression.
filter_string = " && ".join(
[
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
for key, value in filter.items()
]
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
)
def reset(self):
# Resets the database. This will delete all collections and item entries.