diff --git a/CHANGELOG.md b/CHANGELOG.md
index dad583399..1eaffc692 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.1.121] - 2024-04-22
+
+### Added
+
+- **🛠️ Improved Embedding Model Support**: You can now use any embedding model supported by `sentence_transformers`.
+
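A minimal sketch of what this entry enables, assuming only the public `sentence_transformers` API (the model id below is an example, not a new default):

```python
# Any model id loadable by sentence_transformers can now back RAG embeddings.
# "intfloat/multilingual-e5-base" is only an example.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-base", device="cpu")
vectors = model.encode(["Hello world"]).tolist()  # one vector per input text
```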
## [0.1.120] - 2024-04-20
### Added
diff --git a/Dockerfile b/Dockerfile
index f19952909..a8f664ada 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
-ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
+# IMPORTANT: If you change the embedding model (default: sentence-transformers/all-MiniLM-L6-v2), documents already loaded into the WebUI cannot be used for RAG Chat until you re-embed them.
+ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@@ -98,13 +98,13 @@ RUN pip3 install uv && \
# If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
- python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
- python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
- python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
- python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
fi
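Expanded for readability, the two build-time one-liners above are equivalent to the sketch below (env names come from the Dockerfile; this is illustration, not part of the image):

```python
# Expanded form of the build-time pre-download one-liners above (sketch only).
import os

from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer

# Fetch the embedding model into the sentence-transformers cache at build time,
# so the first request does not trigger a download.
SentenceTransformer(os.environ["RAG_EMBEDDING_MODEL"], device="cpu")

# Same idea for the speech-to-text model.
WhisperModel(
    os.environ["WHISPER_MODEL"],
    device="cpu",
    compute_type="int8",
    download_root=os.environ["WHISPER_MODEL_DIR"],
)
```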
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index ac8410dbe..5da7489f1 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path
from typing import List
-from chromadb.utils import embedding_functions
from chromadb.utils.batch_utils import create_batches
from langchain_community.document_loaders import (
@@ -38,6 +37,7 @@ import mimetypes
import uuid
import json
+import sentence_transformers
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
@@ -48,11 +48,8 @@ from apps.web.models.documents import (
)
from apps.rag.utils import (
- query_doc,
query_embeddings_doc,
- query_collection,
query_embeddings_collection,
- get_embedding_model_path,
generate_openai_embeddings,
)
@@ -69,7 +66,7 @@ from config import (
DOCS_DIR,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL,
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
DEVICE_TYPE,
@@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = False
-
-app.state.sentence_transformer_ef = (
- embedding_functions.SentenceTransformerEmbeddingFunction(
- model_name=get_embedding_model_path(
- app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
- ),
+if app.state.RAG_EMBEDDING_ENGINE == "":
+ app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+ app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
-)
origins = ["*"]
@@ -185,13 +179,10 @@ async def update_embedding_config(
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
else:
- sentence_transformer_ef = (
- embedding_functions.SentenceTransformerEmbeddingFunction(
- model_name=get_embedding_model_path(
- form_data.embedding_model, True
- ),
- device=DEVICE_TYPE,
- )
+ sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            form_data.embedding_model,
+ device=DEVICE_TYPE,
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef
@@ -294,38 +285,34 @@ def query_doc_handler(
form_data: QueryDocForm,
user=Depends(get_current_user),
):
-
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
- return query_doc(
- collection_name=form_data.collection_name,
- query=form_data.query,
- k=form_data.k if form_data.k else app.state.TOP_K,
- embedding_function=app.state.sentence_transformer_ef,
+ query_embeddings = app.state.sentence_transformer_ef.encode(
+ form_data.query
+ ).tolist()
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+ query_embeddings = generate_ollama_embeddings(
+ GenerateEmbeddingsForm(
+ **{
+ "model": app.state.RAG_EMBEDDING_MODEL,
+ "prompt": form_data.query,
+ }
+ )
+ )
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+ query_embeddings = generate_openai_embeddings(
+ model=app.state.RAG_EMBEDDING_MODEL,
+ text=form_data.query,
+ key=app.state.OPENAI_API_KEY,
+ url=app.state.OPENAI_API_BASE_URL,
)
- else:
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
- query_embeddings = generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{
- "model": app.state.RAG_EMBEDDING_MODEL,
- "prompt": form_data.query,
- }
- )
- )
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
- query_embeddings = generate_openai_embeddings(
- model=app.state.RAG_EMBEDDING_MODEL,
- text=form_data.query,
- key=app.state.OPENAI_API_KEY,
- url=app.state.OPENAI_API_BASE_URL,
- )
- return query_embeddings_doc(
- collection_name=form_data.collection_name,
- query_embeddings=query_embeddings,
- k=form_data.k if form_data.k else app.state.TOP_K,
- )
+ return query_embeddings_doc(
+ collection_name=form_data.collection_name,
+ query=form_data.query,
+ query_embeddings=query_embeddings,
+ k=form_data.k if form_data.k else app.state.TOP_K,
+ )
except Exception as e:
log.exception(e)
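This handler and `query_collection_handler` below share the same three-way dispatch. A hypothetical helper capturing the pattern (the name `compute_query_embeddings` is illustrative and assumes the imports already present in `main.py`):

```python
# Illustrative only: the dispatch used by both query handlers, factored out.
def compute_query_embeddings(app, query: str) -> list:
    if app.state.RAG_EMBEDDING_ENGINE == "":
        # Local sentence-transformers model loaded at startup.
        return app.state.sentence_transformer_ef.encode(query).tolist()
    elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
        return generate_ollama_embeddings(
            GenerateEmbeddingsForm(model=app.state.RAG_EMBEDDING_MODEL, prompt=query)
        )
    elif app.state.RAG_EMBEDDING_ENGINE == "openai":
        return generate_openai_embeddings(
            model=app.state.RAG_EMBEDDING_MODEL,
            text=query,
            key=app.state.OPENAI_API_KEY,
            url=app.state.OPENAI_API_BASE_URL,
        )
    raise ValueError(f"Unknown embedding engine: {app.state.RAG_EMBEDDING_ENGINE}")
```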
@@ -348,36 +335,31 @@ def query_collection_handler(
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
- return query_collection(
- collection_names=form_data.collection_names,
- query=form_data.query,
- k=form_data.k if form_data.k else app.state.TOP_K,
- embedding_function=app.state.sentence_transformer_ef,
- )
- else:
-
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
- query_embeddings = generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{
- "model": app.state.RAG_EMBEDDING_MODEL,
- "prompt": form_data.query,
- }
- )
+ query_embeddings = app.state.sentence_transformer_ef.encode(
+ form_data.query
+ ).tolist()
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+ query_embeddings = generate_ollama_embeddings(
+ GenerateEmbeddingsForm(
+ **{
+ "model": app.state.RAG_EMBEDDING_MODEL,
+ "prompt": form_data.query,
+ }
)
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
- query_embeddings = generate_openai_embeddings(
- model=app.state.RAG_EMBEDDING_MODEL,
- text=form_data.query,
- key=app.state.OPENAI_API_KEY,
- url=app.state.OPENAI_API_BASE_URL,
- )
-
- return query_embeddings_collection(
- collection_names=form_data.collection_names,
- query_embeddings=query_embeddings,
- k=form_data.k if form_data.k else app.state.TOP_K,
)
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+ query_embeddings = generate_openai_embeddings(
+ model=app.state.RAG_EMBEDDING_MODEL,
+ text=form_data.query,
+ key=app.state.OPENAI_API_KEY,
+ url=app.state.OPENAI_API_BASE_URL,
+ )
+
+ return query_embeddings_collection(
+ collection_names=form_data.collection_names,
+ query_embeddings=query_embeddings,
+ k=form_data.k if form_data.k else app.state.TOP_K,
+ )
except Exception as e:
log.exception(e)
@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs]
+    texts = [text.replace("\n", " ") for text in texts]
+
metadatas = [doc.metadata for doc in docs]
try:
@@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
+ collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
if app.state.RAG_EMBEDDING_ENGINE == "":
-
- collection = CHROMA_CLIENT.create_collection(
- name=collection_name,
- embedding_function=app.state.sentence_transformer_ef,
- )
-
- for batch in create_batches(
- api=CHROMA_CLIENT,
- ids=[str(uuid.uuid1()) for _ in texts],
- metadatas=metadatas,
- documents=texts,
- ):
- collection.add(*batch)
-
- else:
- collection = CHROMA_CLIENT.create_collection(name=collection_name)
-
- if app.state.RAG_EMBEDDING_ENGINE == "ollama":
- embeddings = [
- generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
- )
+ embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
+ elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+ embeddings = [
+ generate_ollama_embeddings(
+ GenerateEmbeddingsForm(
+ **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
- for text in texts
- ]
- elif app.state.RAG_EMBEDDING_ENGINE == "openai":
- embeddings = [
- generate_openai_embeddings(
- model=app.state.RAG_EMBEDDING_MODEL,
- text=text,
- key=app.state.OPENAI_API_KEY,
- url=app.state.OPENAI_API_BASE_URL,
- )
- for text in texts
- ]
+ )
+ for text in texts
+ ]
+ elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+ embeddings = [
+ generate_openai_embeddings(
+ model=app.state.RAG_EMBEDDING_MODEL,
+ text=text,
+ key=app.state.OPENAI_API_KEY,
+ url=app.state.OPENAI_API_BASE_URL,
+ )
+ for text in texts
+ ]
- for batch in create_batches(
- api=CHROMA_CLIENT,
- ids=[str(uuid.uuid1()) for _ in texts],
- metadatas=metadatas,
- embeddings=embeddings,
- documents=texts,
- ):
- collection.add(*batch)
+ for batch in create_batches(
+ api=CHROMA_CLIENT,
+ ids=[str(uuid.uuid1()) for _ in texts],
+ metadatas=metadatas,
+ embeddings=embeddings,
+ documents=texts,
+ ):
+ collection.add(*batch)
return True
except Exception as e:
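Because collections are now created without an `embedding_function`, Chroma no longer embeds anything itself; both `add()` and `query()` must receive precomputed vectors. A self-contained sketch of that contract (in-memory client, example model id):

```python
# Sketch: a collection created without an embedding_function needs precomputed
# vectors for both add() and query(). Model id is an example.
import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.Client()  # in-memory client for illustration
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

texts = ["Paris is the capital of France.", "The Nile flows through Egypt."]
collection = client.create_collection(name="demo")
collection.add(
    ids=["0", "1"],
    documents=texts,
    embeddings=model.encode(texts).tolist(),
)

result = collection.query(
    query_embeddings=[model.encode("Where is Paris?").tolist()],
    n_results=1,
)
```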
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index f4d1246c7..0ce299279 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -1,13 +1,12 @@
-import os
-import re
import logging
-from typing import List
import requests
+from typing import List
-from huggingface_hub import snapshot_download
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
+from apps.ollama.main import (
+ generate_ollama_embeddings,
+ GenerateEmbeddingsForm,
+)
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-def query_doc(collection_name: str, query: str, k: int, embedding_function):
- try:
- # if you use docker use the model from the environment variable
- collection = CHROMA_CLIENT.get_collection(
- name=collection_name,
- embedding_function=embedding_function,
- )
- result = collection.query(
- query_texts=[query],
- n_results=k,
- )
- return result
- except Exception as e:
- raise e
-
-
-def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
+def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
log.info(f"query_embeddings_doc {query_embeddings}")
- collection = CHROMA_CLIENT.get_collection(
- name=collection_name,
- )
+ collection = CHROMA_CLIENT.get_collection(name=collection_name)
+
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
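Callers now embed the query themselves and pass the vector in; for example (the collection name is hypothetical, and the model instance comes from app startup):

```python
# Illustrative call with a precomputed query vector; "docs-1234" is hypothetical.
query = "What does the contract say about termination?"
query_embeddings = app.state.sentence_transformer_ef.encode(query).tolist()

result = query_embeddings_doc(
    collection_name="docs-1234",
    query=query,
    query_embeddings=query_embeddings,
    k=4,
)
```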
@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
return merged_query_results
-def query_collection(
- collection_names: List[str], query: str, k: int, embedding_function
+def query_embeddings_collection(
+ collection_names: List[str], query: str, query_embeddings, k: int
):
- results = []
-
- for collection_name in collection_names:
- try:
- # if you use docker use the model from the environment variable
- collection = CHROMA_CLIENT.get_collection(
- name=collection_name,
- embedding_function=embedding_function,
- )
-
- result = collection.query(
- query_texts=[query],
- n_results=k,
- )
- results.append(result)
- except:
- pass
-
- return merge_and_sort_query_results(results, k)
-
-
-def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
-
results = []
log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names:
try:
- collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
- result = collection.query(
- query_embeddings=[query_embeddings],
- n_results=k,
+ result = query_embeddings_doc(
+ collection_name=collection_name,
+ query=query,
+ query_embeddings=query_embeddings,
+ k=k,
)
results.append(result)
except:
@@ -197,51 +156,38 @@ def rag_messages(
context = doc["content"]
else:
if embedding_engine == "":
- if doc["type"] == "collection":
- context = query_collection(
- collection_names=doc["collection_names"],
- query=query,
- k=k,
- embedding_function=embedding_function,
- )
- else:
- context = query_doc(
- collection_name=doc["collection_name"],
- query=query,
- k=k,
- embedding_function=embedding_function,
+ query_embeddings = embedding_function.encode(query).tolist()
+ elif embedding_engine == "ollama":
+ query_embeddings = generate_ollama_embeddings(
+ GenerateEmbeddingsForm(
+ **{
+ "model": embedding_model,
+ "prompt": query,
+ }
)
+ )
+ elif embedding_engine == "openai":
+ query_embeddings = generate_openai_embeddings(
+ model=embedding_model,
+ text=query,
+ key=openai_key,
+ url=openai_url,
+ )
+ if doc["type"] == "collection":
+ context = query_embeddings_collection(
+ collection_names=doc["collection_names"],
+ query=query,
+ query_embeddings=query_embeddings,
+ k=k,
+ )
else:
- if embedding_engine == "ollama":
- query_embeddings = generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{
- "model": embedding_model,
- "prompt": query,
- }
- )
- )
- elif embedding_engine == "openai":
- query_embeddings = generate_openai_embeddings(
- model=embedding_model,
- text=query,
- key=openai_key,
- url=openai_url,
- )
-
- if doc["type"] == "collection":
- context = query_embeddings_collection(
- collection_names=doc["collection_names"],
- query_embeddings=query_embeddings,
- k=k,
- )
- else:
- context = query_embeddings_doc(
- collection_name=doc["collection_name"],
- query_embeddings=query_embeddings,
- k=k,
- )
+ context = query_embeddings_doc(
+ collection_name=doc["collection_name"],
+ query=query,
+ query_embeddings=query_embeddings,
+ k=k,
+ )
except Exception as e:
log.exception(e)
@@ -283,46 +229,6 @@ def rag_messages(
return messages
-def get_embedding_model_path(
- embedding_model: str, update_embedding_model: bool = False
-):
- # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
- cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
-
- local_files_only = not update_embedding_model
-
- snapshot_kwargs = {
- "cache_dir": cache_dir,
- "local_files_only": local_files_only,
- }
-
- log.debug(f"embedding_model: {embedding_model}")
- log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
-
- # Inspiration from upstream sentence_transformers
- if (
- os.path.exists(embedding_model)
- or ("\\" in embedding_model or embedding_model.count("/") > 1)
- and local_files_only
- ):
- # If fully qualified path exists, return input, else set repo_id
- return embedding_model
- elif "/" not in embedding_model:
- # Set valid repo_id for model short-name
- embedding_model = "sentence-transformers" + "/" + embedding_model
-
- snapshot_kwargs["repo_id"] = embedding_model
-
- # Attempt to query the huggingface_hub library to determine the local path and/or to update
- try:
- embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
- log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
- return embedding_model_repo_path
- except Exception as e:
- log.exception(f"Cannot determine embedding model snapshot path: {e}")
- return embedding_model
-
-
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
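The callers above feed the return value of `generate_openai_embeddings` straight to Chroma as a single embedding. A hedged usage sketch (model name and key are placeholders):

```python
# Usage sketch; the model name and API key are placeholders.
query_embeddings = generate_openai_embeddings(
    model="text-embedding-3-small",
    text="Where is Paris?",
    key="sk-...",
    url="https://api.openai.com/v1",
)
```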
diff --git a/backend/config.py b/backend/config.py
index 6ca2c67bf..17f8f91bf 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -411,18 +411,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+# this uses the model defined in the Dockerfile ENV variable. If you don't use docker or docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
-RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+RAG_EMBEDDING_MODEL = os.environ.get(
+ "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
+)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}")
-RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
- os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
+ os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
-
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
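`RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE` is passed through to `SentenceTransformer(..., trust_remote_code=...)`; some community models ship custom modeling code and will not load without it. A sketch of the knobs as environment variables (the model id is an example of such a model, not a recommendation):

```python
# Sketch of the environment variables consumed above; set them before config.py is imported.
import os

os.environ["RAG_EMBEDDING_MODEL"] = "nomic-ai/nomic-embed-text-v1"  # example only
os.environ["RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE"] = "true"  # any other value stays False
```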
diff --git a/backend/requirements.txt b/backend/requirements.txt
index c815d93da..d5c179d86 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -25,6 +25,7 @@ apscheduler
google-generativeai
langchain
+langchain-chroma
langchain-community
fake_useragent
chromadb
@@ -43,6 +44,7 @@ opencv-python-headless
rapidocr-onnxruntime
fpdf2
+rank_bm25
faster-whisper
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index a2bbec852..6abdda5af 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -180,7 +180,7 @@
}
}}
>
-
+