mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
wip
This commit is contained in:
22
backend/open_webui/retrieval/vector/connector.py
Normal file
22
backend/open_webui/retrieval/vector/connector.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from open_webui.config import VECTOR_DB
|
||||
|
||||
if VECTOR_DB == "milvus":
|
||||
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
elif VECTOR_DB == "opensearch":
|
||||
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = OpenSearchClient()
|
||||
elif VECTOR_DB == "pgvector":
|
||||
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||
|
||||
VECTOR_DB_CLIENT = PgvectorClient()
|
||||
else:
|
||||
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
VECTOR_DB_CLIENT = ChromaClient()
|
||||
174
backend/open_webui/retrieval/vector/dbs/chroma.py
Normal file
174
backend/open_webui/retrieval/vector/dbs/chroma.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import chromadb
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
CHROMA_HTTP_PORT,
|
||||
CHROMA_HTTP_HEADERS,
|
||||
CHROMA_HTTP_SSL,
|
||||
CHROMA_TENANT,
|
||||
CHROMA_DATABASE,
|
||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
)
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
"anonymized_telemetry": False,
|
||||
}
|
||||
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
||||
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
||||
settings_dict["chroma_client_auth_credentials"] = (
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
)
|
||||
|
||||
if CHROMA_HTTP_HOST != "":
|
||||
self.client = chromadb.HttpClient(
|
||||
host=CHROMA_HTTP_HOST,
|
||||
port=CHROMA_HTTP_PORT,
|
||||
headers=CHROMA_HTTP_HEADERS,
|
||||
ssl=CHROMA_HTTP_SSL,
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
settings=Settings(**settings_dict),
|
||||
)
|
||||
else:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=CHROMA_DATA_PATH,
|
||||
settings=Settings(**settings_dict),
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collections = self.client.list_collections()
|
||||
return collection_name in [collection.name for collection in collections]
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
return self.client.delete_collection(name=collection_name)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.query(
|
||||
query_embeddings=vectors,
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.get(
|
||||
where=filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.get()
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
|
||||
for batch in create_batches(
|
||||
api=self.client,
|
||||
documents=documents,
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
elif filter:
|
||||
collection.delete(where=filter)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
return self.client.reset()
|
||||
286
backend/open_webui/retrieval/vector/dbs/milvus.py
Normal file
286
backend/open_webui/retrieval/vector/dbs/milvus.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
)
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
self.client = Client(uri=MILVUS_URI)
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_documents.append(item.get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_distances = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_distances.append(item.get("distance"))
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
schema = self.client.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=True,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
description="vector",
|
||||
)
|
||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
||||
schema.add_field(
|
||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
||||
)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="HNSW",
|
||||
metric_type="COSINE",
|
||||
params={"M": 16, "efConstruction": 100},
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.drop_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
max_limit = 16383 # The maximum number of records per request
|
||||
all_results = []
|
||||
|
||||
if limit is None:
|
||||
limit = float("inf") # Use infinity as a placeholder for no limit
|
||||
|
||||
# Initialize offset and remaining to handle pagination
|
||||
offset = 0
|
||||
remaining = limit
|
||||
|
||||
try:
|
||||
# Loop until there are no more items to fetch or the desired limit is reached
|
||||
while remaining > 0:
|
||||
print("remaining", remaining)
|
||||
current_fetch = min(
|
||||
max_limit, remaining
|
||||
) # Determine how many items to fetch in this iteration
|
||||
|
||||
results = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
output_fields=["*"],
|
||||
limit=current_fetch,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
if not results:
|
||||
break
|
||||
|
||||
all_results.extend(results)
|
||||
results_count = len(results)
|
||||
remaining -= (
|
||||
results_count # Decrease remaining by the number of items fetched
|
||||
)
|
||||
offset += results_count
|
||||
|
||||
# Break the loop if the results returned are less than the requested fetch count
|
||||
if results_count < current_fetch:
|
||||
break
|
||||
|
||||
print(all_results)
|
||||
return self._result_to_get_result([all_results])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
result = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter='id != ""',
|
||||
)
|
||||
return self._result_to_get_result([result])
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if ids:
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
# Convert the filter dictionary to a string using JSON_CONTAINS.
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.list_collections()
|
||||
for collection_name in collection_names:
|
||||
if collection_name.startswith(self.collection_prefix):
|
||||
self.client.drop_collection(collection_name=collection_name)
|
||||
178
backend/open_webui/retrieval/vector/dbs/opensearch.py
Normal file
178
backend/open_webui/retrieval/vector/dbs/opensearch.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from opensearchpy import OpenSearch
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_SSL,
|
||||
OPENSEARCH_CERT_VERIFY,
|
||||
OPENSEARCH_USERNAME,
|
||||
OPENSEARCH_PASSWORD,
|
||||
)
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui"
|
||||
self.client = OpenSearch(
|
||||
hosts=[OPENSEARCH_URI],
|
||||
use_ssl=OPENSEARCH_SSL,
|
||||
verify_certs=OPENSEARCH_CERT_VERIFY,
|
||||
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
|
||||
)
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
|
||||
def _create_index(self, index_name: str, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": true,
|
||||
"similarity": "faiss",
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "ip", # Use inner product to approximate cosine similarity
|
||||
"engine": "faiss",
|
||||
"ef_construction": 128,
|
||||
"m": 16,
|
||||
},
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
}
|
||||
}
|
||||
}
|
||||
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
for i in range(0, len(items), batch_size):
|
||||
yield items[i : i + batch_size]
|
||||
|
||||
def has_collection(self, index_name: str) -> bool:
|
||||
# has_collection here means has index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
|
||||
|
||||
def delete_colleciton(self, index_name: str):
|
||||
# delete_collection here means delete index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
|
||||
|
||||
def search(
|
||||
self, index_name: str, vectors: list[list[float]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
"params": {
|
||||
"vector": vectors[0]
|
||||
}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{index_name}", body=query
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def get_or_create_index(self, index_name: str, dimension: int):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension)
|
||||
|
||||
def get(self, index_name: str) -> Optional[GetResult]:
|
||||
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
|
||||
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}_{index_name}", body=query
|
||||
)
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
def insert(self, index_name: str, items: list[VectorItem]):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"index": {
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
self.client.bulk(actions)
|
||||
|
||||
def upsert(self, index_name: str, items: list[VectorItem]):
|
||||
if not self.has_index(index_name):
|
||||
self._create_index(index_name, dimension=len(items[0]["vector"]))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"index": {
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
},
|
||||
}
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
self.client.bulk(actions)
|
||||
|
||||
def delete(self, index_name: str, ids: list[str]):
|
||||
actions = [
|
||||
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
|
||||
for id in ids
|
||||
]
|
||||
self.client.bulk(body=actions)
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
|
||||
for index in indices:
|
||||
self.client.indices.delete(index=index)
|
||||
354
backend/open_webui/retrieval/vector/dbs/pgvector.py
Normal file
354
backend/open_webui/retrieval/vector/dbs/pgvector.py
Normal file
@@ -0,0 +1,354 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy import (
|
||||
cast,
|
||||
column,
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
values,
|
||||
)
|
||||
from sqlalchemy.sql import true
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import PGVECTOR_DB_URL
|
||||
|
||||
VECTOR_LENGTH = 1536
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||
collection_name = Column(Text, nullable=False)
|
||||
text = Column(Text, nullable=True)
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
||||
class PgvectorClient:
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
if not PGVECTOR_DB_URL:
|
||||
from open_webui.internal.db import Session
|
||||
|
||||
self.session = Session
|
||||
else:
|
||||
engine = create_engine(
|
||||
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
self.session = scoped_session(SessionLocal)
|
||||
|
||||
try:
|
||||
# Ensure the pgvector extension is available
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
|
||||
# Create the tables if they do not exist
|
||||
# Base.metadata.create_all requires a bind (engine or connection)
|
||||
# Get the connection from the session
|
||||
connection = self.session.connection()
|
||||
Base.metadata.create_all(bind=connection)
|
||||
|
||||
# Create an index on the vector column if it doesn't exist
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
||||
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
||||
)
|
||||
)
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
"ON document_chunk (collection_name);"
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
print("Initialization complete.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
# Adjust vector to have length VECTOR_LENGTH
|
||||
current_length = len(vector)
|
||||
if current_length < VECTOR_LENGTH:
|
||||
# Pad the vector with zeros
|
||||
vector += [0.0] * (VECTOR_LENGTH - current_length)
|
||||
elif current_length > VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
|
||||
)
|
||||
return vector
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
print(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = item["metadata"]
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: List[List[float]],
|
||||
limit: Optional[int] = None,
|
||||
) -> Optional[SearchResult]:
|
||||
try:
|
||||
if not vectors:
|
||||
return None
|
||||
|
||||
# Adjust query vectors to VECTOR_LENGTH
|
||||
vectors = [self.adjust_vector_length(vector) for vector in vectors]
|
||||
num_queries = len(vectors)
|
||||
|
||||
def vector_expr(vector):
|
||||
return cast(array(vector), Vector(VECTOR_LENGTH))
|
||||
|
||||
# Create the values for query vectors
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
)
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
subq = (
|
||||
select(
|
||||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
).label("distance"),
|
||||
)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
subq = subq.lateral("result")
|
||||
|
||||
# Build the main query by joining query_vectors and the lateral subquery
|
||||
stmt = (
|
||||
select(
|
||||
query_vectors.c.qid,
|
||||
subq.c.id,
|
||||
subq.c.text,
|
||||
subq.c.vmetadata,
|
||||
subq.c.distance,
|
||||
)
|
||||
.select_from(query_vectors)
|
||||
.join(subq, true())
|
||||
.order_by(query_vectors.c.qid, subq.c.distance)
|
||||
)
|
||||
|
||||
result_proxy = self.session.execute(stmt)
|
||||
results = result_proxy.all()
|
||||
|
||||
ids = [[] for _ in range(num_queries)]
|
||||
distances = [[] for _ in range(num_queries)]
|
||||
documents = [[] for _ in range(num_queries)]
|
||||
metadatas = [[] for _ in range(num_queries)]
|
||||
|
||||
if not results:
|
||||
return SearchResult(
|
||||
ids=ids,
|
||||
distances=distances,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for row in results:
|
||||
qid = int(row.qid)
|
||||
ids[qid].append(row.id)
|
||||
distances[qid].append(row.distance)
|
||||
documents[qid].append(row.text)
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during search: {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
ids = [[result.id for result in results]]
|
||||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
return GetResult(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during query: {e}")
|
||||
return None
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
ids = [[result.id for result in results]]
|
||||
documents = [[result.text for result in results]]
|
||||
metadatas = [[result.vmetadata for result in results]]
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
print(f"Error during get: {e}")
|
||||
return None
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
if ids:
|
||||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
print(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during delete: {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
print(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
print(f"Error during reset: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
try:
|
||||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
return exists
|
||||
except Exception as e:
|
||||
print(f"Error checking collection existence: {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
self.delete(collection_name)
|
||||
print(f"Collection '{collection_name}' deleted.")
|
||||
184
backend/open_webui/retrieval/vector/dbs/qdrant.py
Normal file
184
backend/open_webui/retrieval/vector/dbs/qdrant.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
self.QDRANT_API_KEY = QDRANT_API_KEY
|
||||
self.client = (
|
||||
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
||||
if self.QDRANT_URI
|
||||
else None
|
||||
)
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(
|
||||
size=dimension, distance=models.Distance.COSINE
|
||||
),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(collection_name=collection_name):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
|
||||
def _create_points(self, items: list[VectorItem]):
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={"text": item["text"], "metadata": item["metadata"]},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(
|
||||
f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
query_response = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query=vectors[0],
|
||||
limit=limit,
|
||||
)
|
||||
get_result = self._result_to_get_result(query_response.points)
|
||||
return SearchResult(
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]],
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
try:
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
)
|
||||
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
field_conditions = []
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
),
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
),
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=field_conditions)
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.get_collections().collections
|
||||
for collection_name in collection_names:
|
||||
if collection_name.name.startswith(self.collection_prefix):
|
||||
self.client.delete_collection(collection_name=collection_name.name)
|
||||
19
backend/open_webui/retrieval/vector/main.py
Normal file
19
backend/open_webui/retrieval/vector/main.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Any
|
||||
|
||||
|
||||
class VectorItem(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
vector: List[float | int]
|
||||
metadata: Any
|
||||
|
||||
|
||||
class GetResult(BaseModel):
|
||||
ids: Optional[List[List[str]]]
|
||||
documents: Optional[List[List[str]]]
|
||||
metadatas: Optional[List[List[Any]]]
|
||||
|
||||
|
||||
class SearchResult(GetResult):
|
||||
distances: Optional[List[List[float | int]]]
|
||||
Reference in New Issue
Block a user