mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'upstream-dev' into dev
This commit is contained in:
@@ -63,7 +63,7 @@ from open_webui.config import (
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
RAG_EMBEDDING_BATCH_SIZE,
|
||||
RAG_FILE_MAX_COUNT,
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -267,7 +267,7 @@ async def get_status():
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
||||
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
}
|
||||
|
||||
|
||||
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||
"status": True,
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
batch_size: Optional[int] = None
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
|
||||
|
||||
@app.post("/embedding/update")
|
||||
@@ -320,11 +320,7 @@ async def update_embedding_config(
|
||||
if form_data.openai_config is not None:
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
|
||||
form_data.openai_config.batch_size
|
||||
if form_data.openai_config.batch_size
|
||||
else 1
|
||||
)
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||
|
||||
@@ -334,17 +330,17 @@ async def update_embedding_config(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -645,7 +641,7 @@ def save_docs_to_vector_db(
|
||||
filter={"hash": metadata["hash"]},
|
||||
)
|
||||
|
||||
if result:
|
||||
if result is not None:
|
||||
existing_doc_ids = result.ids[0]
|
||||
if existing_doc_ids:
|
||||
log.info(f"Document with hash {metadata['hash']} already exists")
|
||||
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
@@ -767,7 +763,7 @@ def process_file(
|
||||
collection_name=f"file-{file.id}", filter={"file_id": file.id}
|
||||
)
|
||||
|
||||
if len(result.ids[0]) > 0:
|
||||
if result is not None and len(result.ids[0]) > 0:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=result.documents[0][idx],
|
||||
|
||||
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbeddingsForm,
|
||||
generate_ollama_embeddings,
|
||||
GenerateEmbedForm,
|
||||
generate_ollama_batch_embeddings,
|
||||
)
|
||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
@@ -193,7 +193,8 @@ def query_collection(
|
||||
k=k,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
results.append(result.model_dump())
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
@@ -265,39 +266,27 @@ def get_embedding_function(
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
batch_size,
|
||||
embedding_batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
if embedding_engine == "ollama":
|
||||
func = lambda query: generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
func = lambda query: generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
func = lambda query: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key if embedding_engine == "openai" else "",
|
||||
url=openai_url if embedding_engine == "openai" else "",
|
||||
)
|
||||
|
||||
def generate_multiple(query, f):
|
||||
def generate_multiple(query, func):
|
||||
if isinstance(query, list):
|
||||
if embedding_engine == "openai":
|
||||
embeddings = []
|
||||
for i in range(0, len(query), batch_size):
|
||||
embeddings.extend(f(query[i : i + batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return [f(q) for q in query]
|
||||
embeddings = []
|
||||
for i in range(0, len(query), embedding_batch_size):
|
||||
embeddings.extend(func(query[i : i + embedding_batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return f(query)
|
||||
return func(query)
|
||||
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
@@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str,
|
||||
text: Union[str, list[str]],
|
||||
key: str,
|
||||
url: str = "https://api.openai.com/v1",
|
||||
):
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
) -> Optional[list[list[float]]]:
|
||||
@@ -482,6 +457,33 @@ def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": text})
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": [text]})
|
||||
)
|
||||
return (
|
||||
embeddings["embeddings"][0]
|
||||
if isinstance(text, str)
|
||||
else embeddings["embeddings"]
|
||||
)
|
||||
elif engine == "openai":
|
||||
key = kwargs.get("key", "")
|
||||
url = kwargs.get("url", "https://api.openai.com/v1")
|
||||
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
import operator
|
||||
from typing import Optional, Sequence
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
|
||||
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
else:
|
||||
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
|
||||
176
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
176
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import QDRANT_URI
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(
|
||||
collection_name=collection_name
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
|
||||
def _create_points(self, items: list[VectorItem]):
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"]
|
||||
},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
query_response = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query=vectors[0],
|
||||
limit=limit,
|
||||
)
|
||||
get_result = self._result_to_get_result(query_response.points)
|
||||
return SearchResult(
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]]
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
try:
|
||||
if limit is None:
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
|
||||
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
query_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
points = self._create_points(items)
|
||||
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
field_conditions = []
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
),
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
),
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(
|
||||
must=field_conditions
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.get_collections().collections
|
||||
for collection_name in collection_names:
|
||||
if collection_name.name.startswith(self.collection_prefix):
|
||||
self.client.delete_collection(collection_name=collection_name.name)
|
||||
Reference in New Issue
Block a user