Merge branch 'upstream-dev' into dev

This commit is contained in:
Jannik Streidl
2024-10-12 15:18:59 +02:00
103 changed files with 2956 additions and 3023 deletions

View File

@@ -63,7 +63,7 @@ from open_webui.config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
}
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -334,17 +330,17 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -645,7 +641,7 @@ def save_docs_to_vector_db(
filter={"hash": metadata["hash"]},
)
if result:
if result is not None:
existing_doc_ids = result.ids[0]
if existing_doc_ids:
log.info(f"Document with hash {metadata['hash']} already exists")
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
embeddings = embedding_function(
@@ -767,7 +763,7 @@ def process_file(
collection_name=f"file-{file.id}", filter={"file_id": file.id}
)
if len(result.ids[0]) > 0:
if result is not None and len(result.ids[0]) > 0:
docs = [
Document(
page_content=result.documents[0][idx],

View File

@@ -12,8 +12,8 @@ from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@@ -193,7 +193,8 @@ def query_collection(
k=k,
query_embedding=query_embedding,
)
results.append(result.model_dump())
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
@@ -265,39 +266,27 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
batch_size,
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
func = lambda query: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
)
def generate_multiple(query, f):
def generate_multiple(query, func):
if isinstance(query, list):
if embedding_engine == "openai":
embeddings = []
for i in range(0, len(query), batch_size):
embeddings.extend(f(query[i : i + batch_size]))
return embeddings
else:
return [f(q) for q in query]
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings
else:
return f(query)
return func(query)
return lambda query: generate_multiple(query, func)
@@ -445,20 +434,6 @@ def get_model_path(model: str, update_model: bool = False):
return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
@@ -482,6 +457,33 @@ def generate_openai_batch_embeddings(
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
import operator
from typing import Optional, Sequence

View File

@@ -4,6 +4,10 @@ if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient

View File

@@ -0,0 +1,176 @@
from typing import Optional
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _create_collection(self, collection_name: str, dimension: int):
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
)
print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(
collection_name=collection_name
):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
def _create_points(self, items: list[VectorItem]):
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"]
},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
def delete_collection(self, collection_name: str):
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query=vectors[0],
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]]
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
if not self.has_collection(collection_name):
return None
try:
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query_filter=models.Filter(should=field_conditions),
limit=limit,
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
if ids:
for id_value in ids:
field_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
),
elif filter:
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
),
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(
must=field_conditions
)
),
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)