chore: format

This commit is contained in:
Timothy Jaeryang Baek
2024-11-16 23:46:12 -08:00
parent c24bc60d35
commit c338f2cae1
64 changed files with 2668 additions and 1212 deletions

View File

@@ -27,7 +27,9 @@ class ChromaClient:
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = CHROMA_CLIENT_AUTH_CREDENTIALS
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(

View File

@@ -7,9 +7,10 @@ from open_webui.config import (
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
OPENSEARCH_PASSWORD
OPENSEARCH_PASSWORD,
)
class OpenSearchClient:
def __init__(self):
self.index_prefix = "open_webui"
@@ -25,10 +26,10 @@ class OpenSearchClient:
documents = []
metadatas = []
for hit in result['hits']['hits']:
ids.append(hit['_id'])
documents.append(hit['_source'].get("text"))
metadatas.append(hit['_source'].get("metadata"))
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
@@ -38,13 +39,15 @@ class OpenSearchClient:
documents = []
metadatas = []
for hit in result['hits']['hits']:
ids.append(hit['_id'])
distances.append(hit['_score'])
documents.append(hit['_source'].get("text"))
metadatas.append(hit['_source'].get("metadata"))
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
def _create_index(self, index_name: str, dimension: int):
body = {
@@ -52,20 +55,20 @@ class OpenSearchClient:
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"similarity": "faiss",
"method": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16
}
"m": 16,
},
},
"text": {"type": "text"},
"metadata": {"type": "object"}
"metadata": {"type": "object"},
}
}
}
@@ -73,19 +76,21 @@ class OpenSearchClient:
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i:i + batch_size]
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
# has_collection here means has index.
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
def delete_colleciton(self, index_name: str):
# delete_collection here means delete index.
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
def search(self, index_name: str, vectors: list[list[float]], limit: int) -> Optional[SearchResult]:
def search(
self, index_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
@@ -94,15 +99,16 @@ class OpenSearchClient:
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {"vector": vectors[0]} # Assuming single query vector
}
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
}
},
}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}",
body=query
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_search_result(result)
@@ -112,12 +118,11 @@ class OpenSearchClient:
self._create_index(index_name, dimension)
def get(self, index_name: str) -> Optional[GetResult]:
query = {
"query": {"match_all": {}},
"_source": ["text", "metadata"]
}
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(index=f"{self.index_prefix}_{index_name}", body=query)
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
@@ -126,7 +131,16 @@ class OpenSearchClient:
for batch in self._create_batches(items):
actions = [
{"index": {"_id": item["id"], "_source": {"vector": item["vector"], "text": item["text"], "metadata": item["metadata"]}}}
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
@@ -137,13 +151,25 @@ class OpenSearchClient:
for batch in self._create_batches(items):
actions = [
{"index": {"_id": item["id"], "_source": {"vector": item["vector"], "text": item["text"], "metadata": item["metadata"]}}}
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
}
for item in batch
]
self.client.bulk(actions)
def delete(self, index_name: str, ids: list[str]):
actions = [{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}} for id in ids]
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def reset(self):

View File

@@ -44,7 +44,9 @@ class PgvectorClient:
self.session = Session
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)

View File

@@ -15,10 +15,11 @@ class QdrantClient:
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = Qclient(
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY
) if self.QDRANT_URI else None
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
def _result_to_get_result(self, points) -> GetResult:
ids = []