chore: format
This commit is contained in:
@@ -540,7 +540,7 @@ def generate_openai_batch_embeddings(
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
|
||||
r = requests.post(
|
||||
f"{url}/embeddings",
|
||||
headers=headers,
|
||||
@@ -621,7 +621,7 @@ def generate_azure_openai_batch_embeddings(
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
|
||||
r = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -704,7 +704,7 @@ def generate_ollama_batch_embeddings(
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
|
||||
r = requests.post(
|
||||
f"{url}/api/embed",
|
||||
headers=headers,
|
||||
|
||||
@@ -10,22 +10,27 @@ from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.config import WEAVIATE_HTTP_HOST, WEAVIATE_HTTP_PORT, WEAVIATE_GRPC_PORT, WEAVIATE_API_KEY
|
||||
from open_webui.config import (
|
||||
WEAVIATE_HTTP_HOST,
|
||||
WEAVIATE_HTTP_PORT,
|
||||
WEAVIATE_GRPC_PORT,
|
||||
WEAVIATE_API_KEY,
|
||||
)
|
||||
|
||||
|
||||
def _convert_uuids_to_strings(obj: Any) -> Any:
|
||||
"""
|
||||
Recursively convert UUID objects to strings in nested data structures.
|
||||
|
||||
|
||||
This function handles:
|
||||
- UUID objects -> string
|
||||
- Dictionaries with UUID values
|
||||
- Lists/Tuples with UUID values
|
||||
- Nested combinations of the above
|
||||
|
||||
|
||||
Args:
|
||||
obj: Any object that might contain UUIDs
|
||||
|
||||
|
||||
Returns:
|
||||
The same object structure with UUIDs converted to strings
|
||||
"""
|
||||
@@ -41,23 +46,23 @@ def _convert_uuids_to_strings(obj: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
|
||||
class WeaviateClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.url = WEAVIATE_HTTP_HOST
|
||||
try:
|
||||
# Build connection parameters
|
||||
# Build connection parameters
|
||||
connection_params = {
|
||||
"host": WEAVIATE_HTTP_HOST,
|
||||
"port": WEAVIATE_HTTP_PORT,
|
||||
"grpc_port": WEAVIATE_GRPC_PORT,
|
||||
}
|
||||
|
||||
|
||||
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
|
||||
if WEAVIATE_API_KEY:
|
||||
connection_params["auth_credentials"] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
|
||||
connection_params["auth_credentials"] = (
|
||||
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
)
|
||||
|
||||
self.client = weaviate.connect_to_local(**connection_params)
|
||||
self.client.connect()
|
||||
except Exception as e:
|
||||
@@ -73,16 +78,18 @@ class WeaviateClient(VectorDBBase):
|
||||
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
|
||||
|
||||
# Replace hyphens with underscores and keep only alphanumeric characters
|
||||
name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace("-", "_"))
|
||||
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
|
||||
name = name.strip("_")
|
||||
|
||||
if not name:
|
||||
raise ValueError("Could not sanitize collection name to be a valid Weaviate class name")
|
||||
raise ValueError(
|
||||
"Could not sanitize collection name to be a valid Weaviate class name"
|
||||
)
|
||||
|
||||
# Ensure it starts with a letter and is capitalized
|
||||
if not name[0].isalpha():
|
||||
name = "C" + name
|
||||
|
||||
|
||||
return name[0].upper() + name[1:]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
@@ -99,8 +106,10 @@ class WeaviateClient(VectorDBBase):
|
||||
name=collection_name,
|
||||
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
|
||||
properties=[
|
||||
weaviate.classes.config.Property(name="text", data_type=weaviate.classes.config.DataType.TEXT),
|
||||
]
|
||||
weaviate.classes.config.Property(
|
||||
name="text", data_type=weaviate.classes.config.DataType.TEXT
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -109,21 +118,21 @@ class WeaviateClient(VectorDBBase):
|
||||
self._create_collection(sane_collection_name)
|
||||
|
||||
collection = self.client.collections.get(sane_collection_name)
|
||||
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
|
||||
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item["metadata"]))
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
|
||||
batch.add_object(
|
||||
properties=properties,
|
||||
uuid=item_uuid,
|
||||
vector=item["vector"]
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -132,21 +141,21 @@ class WeaviateClient(VectorDBBase):
|
||||
self._create_collection(sane_collection_name)
|
||||
|
||||
collection = self.client.collections.get(sane_collection_name)
|
||||
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(item["id"]) if item["id"] else None
|
||||
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item["metadata"]))
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
|
||||
batch.add_object(
|
||||
properties=properties,
|
||||
uuid=item_uuid,
|
||||
vector=item["vector"]
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -157,9 +166,14 @@ class WeaviateClient(VectorDBBase):
|
||||
return None
|
||||
|
||||
collection = self.client.collections.get(sane_collection_name)
|
||||
|
||||
result_ids, result_documents, result_metadatas, result_distances = [], [], [], []
|
||||
|
||||
|
||||
result_ids, result_documents, result_metadatas, result_distances = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
for vector_embedding in vectors:
|
||||
try:
|
||||
response = collection.query.near_vector(
|
||||
@@ -167,21 +181,28 @@ class WeaviateClient(VectorDBBase):
|
||||
limit=limit,
|
||||
return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
|
||||
ids = [str(obj.uuid) for obj in response.objects]
|
||||
documents = []
|
||||
metadatas = []
|
||||
distances = []
|
||||
|
||||
|
||||
for obj in response.objects:
|
||||
properties = dict(obj.properties) if obj.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
|
||||
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||
raw_distances = [obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0 for obj in response.objects]
|
||||
raw_distances = [
|
||||
(
|
||||
obj.metadata.distance
|
||||
if obj.metadata and obj.metadata.distance
|
||||
else 2.0
|
||||
)
|
||||
for obj in response.objects
|
||||
]
|
||||
distances = [(2 - dist) / 2 for dist in raw_distances]
|
||||
|
||||
|
||||
result_ids.append(ids)
|
||||
result_documents.append(documents)
|
||||
result_metadatas.append(metadatas)
|
||||
@@ -191,7 +212,7 @@ class WeaviateClient(VectorDBBase):
|
||||
result_documents.append([])
|
||||
result_metadatas.append([])
|
||||
result_distances.append([])
|
||||
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result_ids,
|
||||
@@ -209,16 +230,26 @@ class WeaviateClient(VectorDBBase):
|
||||
return None
|
||||
|
||||
collection = self.client.collections.get(sane_collection_name)
|
||||
|
||||
|
||||
weaviate_filter = None
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = prop_filter if weaviate_filter is None else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
|
||||
value
|
||||
)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit)
|
||||
|
||||
response = collection.query.fetch_objects(
|
||||
filters=weaviate_filter, limit=limit
|
||||
)
|
||||
|
||||
ids = [str(obj.uuid) for obj in response.objects]
|
||||
documents = []
|
||||
metadatas = []
|
||||
@@ -252,10 +283,10 @@ class WeaviateClient(VectorDBBase):
|
||||
properties = dict(item.properties) if item.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
|
||||
if not ids:
|
||||
return None
|
||||
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
@@ -285,9 +316,17 @@ class WeaviateClient(VectorDBBase):
|
||||
elif filter:
|
||||
weaviate_filter = None
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = prop_filter if weaviate_filter is None else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(
|
||||
name=key
|
||||
).equal(value)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
)
|
||||
|
||||
if weaviate_filter:
|
||||
collection.data.delete_many(where=weaviate_filter)
|
||||
except Exception:
|
||||
|
||||
@@ -1025,7 +1025,9 @@ def transcription_handler(request, file_path, metadata, user=None):
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None, user=None):
|
||||
def transcribe(
|
||||
request: Request, file_path: str, metadata: Optional[dict] = None, user=None
|
||||
):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
@@ -1052,7 +1054,9 @@ def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit tasks for each chunk_path
|
||||
futures = [
|
||||
executor.submit(transcription_handler, request, chunk_path, metadata, user)
|
||||
executor.submit(
|
||||
transcription_handler, request, chunk_path, metadata, user
|
||||
)
|
||||
for chunk_path in chunk_paths
|
||||
]
|
||||
# Gather results as they complete
|
||||
|
||||
@@ -52,9 +52,7 @@ async def add_memory(
|
||||
):
|
||||
memory = Memories.insert_new_memory(user.id, form_data.content)
|
||||
|
||||
vector = await request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user=user
|
||||
)
|
||||
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||
|
||||
VECTOR_DB_CLIENT.upsert(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
@@ -112,10 +110,12 @@ async def reset_memory_from_vector_db(
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
|
||||
# Generate vectors in parallel
|
||||
vectors = await asyncio.gather(*[
|
||||
request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||
for memory in memories
|
||||
])
|
||||
vectors = await asyncio.gather(
|
||||
*[
|
||||
request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||
for memory in memories
|
||||
]
|
||||
)
|
||||
|
||||
VECTOR_DB_CLIENT.upsert(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
@@ -174,9 +174,7 @@ async def update_memory_by_id(
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
if form_data.content is not None:
|
||||
vector = await request.app.state.EMBEDDING_FUNCTION(
|
||||
memory.content, user=user
|
||||
)
|
||||
vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
|
||||
|
||||
VECTOR_DB_CLIENT.upsert(
|
||||
collection_name=f"user-memory-{user.id}",
|
||||
|
||||
@@ -53,7 +53,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
|
||||
from open_webui.config import (
|
||||
UPLOAD_DIR,
|
||||
)
|
||||
|
||||
@@ -1468,11 +1468,13 @@ def save_docs_to_vector_db(
|
||||
)
|
||||
|
||||
# Run async embedding in sync context
|
||||
embeddings = asyncio.run(embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts)),
|
||||
prefix=RAG_EMBEDDING_CONTENT_PREFIX,
|
||||
user=user,
|
||||
))
|
||||
embeddings = asyncio.run(
|
||||
embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts)),
|
||||
prefix=RAG_EMBEDDING_CONTENT_PREFIX,
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
|
||||
|
||||
items = [
|
||||
|
||||
Reference in New Issue
Block a user