From f3e5700d49d5c7fe609aa16530b1b5d83ae10b90 Mon Sep 17 00:00:00 2001
From: Steven Kreitzer
Date: Mon, 22 Apr 2024 13:27:43 -0500
Subject: [PATCH] feat: move to native sentence_transformer

---
 CHANGELOG.md                              |   6 +
 Dockerfile                                |  12 +-
 backend/apps/rag/main.py                  | 206 ++++++++----------
 backend/apps/rag/utils.py                 | 182 ++++------------
 backend/config.py                         |  11 +-
 backend/requirements.txt                  |   2 +
 .../documents/Settings/General.svelte     |   2 +-
 7 files changed, 153 insertions(+), 268 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dad583399..1eaffc692 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.121] - 2024-04-22
+
+### Added
+
+- **🛠️ Improved Embedding Model Support**: You can now use any embedding model `sentence_transformers` supports.
+
 ## [0.1.120] - 2024-04-20
 
 ### Added
diff --git a/Dockerfile b/Dockerfile
index f19952909..a8f664ada 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
-ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
+# IMPORTANT: If you change the embedding model (default: sentence-transformers/all-MiniLM-L6-v2), you won't be able to use RAG Chat with documents already loaded in the WebUI! You need to re-embed them.
+ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 
 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@@ -98,13 +98,13 @@ RUN pip3 install uv && \
     # If you use CUDA the whisper and embedding model will be downloaded on first use
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
-    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+    python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     else \
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
-    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+    python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     fi
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index ac8410dbe..5da7489f1 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List
 
-from chromadb.utils import embedding_functions
 from chromadb.utils.batch_utils import create_batches
 
 from langchain_community.document_loaders import (
@@ -38,6 +37,7 @@ import mimetypes
 import uuid
 import json
 
+import sentence_transformers
 
 from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
@@ -48,11 +48,8 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
-    query_doc,
     query_embeddings_doc,
-    query_collection,
     query_embeddings_collection,
-    get_embedding_model_path,
     generate_openai_embeddings,
 )
@@ -69,7 +66,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
@@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-
-app.state.sentence_transformer_ef = (
-    embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=get_embedding_model_path(
-            app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
-        ),
+if app.state.RAG_EMBEDDING_ENGINE == "":
+    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+        app.state.RAG_EMBEDDING_MODEL,
         device=DEVICE_TYPE,
+        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     )
-)
 
 origins = ["*"]
@@ -185,13 +179,10 @@ async def update_embedding_config(
         app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
         app.state.OPENAI_API_KEY = form_data.openai_config.key
     else:
-        sentence_transformer_ef = (
-            embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=get_embedding_model_path(
-                    form_data.embedding_model, True
-                ),
-                device=DEVICE_TYPE,
-            )
+        sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            form_data.embedding_model,
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
         )
         app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
         app.state.sentence_transformer_ef = sentence_transformer_ef
@@ -294,38 +285,34 @@ def query_doc_handler(
     form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
-
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_doc(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
+            )
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            query_embeddings = generate_openai_embeddings(
+                model=app.state.RAG_EMBEDDING_MODEL,
+                text=form_data.query,
+                key=app.state.OPENAI_API_KEY,
+                url=app.state.OPENAI_API_BASE_URL,
             )
-        else:
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                query_embeddings = generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{
-                            "model": app.state.RAG_EMBEDDING_MODEL,
-                            "prompt": form_data.query,
-                        }
-                    )
-                )
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                query_embeddings = generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=form_data.query,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-            return query_embeddings_doc(
-                collection_name=form_data.collection_name,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-            )
+
+        return query_embeddings_doc(
+            collection_name=form_data.collection_name,
+            query=form_data.query,
+            query_embeddings=query_embeddings,
+            k=form_data.k if form_data.k else app.state.TOP_K,
+        )
 
     except Exception as e:
         log.exception(e)
@@ -348,36 +335,32 @@ def query_collection_handler(
 ):
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_collection(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-        else:
-
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                query_embeddings = generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{
-                            "model": app.state.RAG_EMBEDDING_MODEL,
-                            "prompt": form_data.query,
-                        }
-                    )
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
                 )
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                query_embeddings = generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=form_data.query,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-
-            return query_embeddings_collection(
-                collection_names=form_data.collection_names,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
             )
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            query_embeddings = generate_openai_embeddings(
+                model=app.state.RAG_EMBEDDING_MODEL,
+                text=form_data.query,
+                key=app.state.OPENAI_API_KEY,
+                url=app.state.OPENAI_API_BASE_URL,
+            )
+
+        return query_embeddings_collection(
+            collection_names=form_data.collection_names,
+            query=form_data.query,
+            query_embeddings=query_embeddings,
+            k=form_data.k if form_data.k else app.state.TOP_K,
+        )
 
     except Exception as e:
         log.exception(e)
@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
+    texts = list(map(lambda x: x.replace("\n", " "), texts))
+
     metadatas = [doc.metadata for doc in docs]
 
     try:
@@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
             log.info(f"deleting existing collection {collection_name}")
             CHROMA_CLIENT.delete_collection(name=collection_name)
 
+        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
         if app.state.RAG_EMBEDDING_ENGINE == "":
-
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                documents=texts,
-            ):
-                collection.add(*batch)
-
-        else:
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
-
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                embeddings = [
-                    generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                        )
+            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            embeddings = [
+                generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                     )
-                    for text in texts
-                ]
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                embeddings = [
-                    generate_openai_embeddings(
-                        model=app.state.RAG_EMBEDDING_MODEL,
-                        text=text,
-                        key=app.state.OPENAI_API_KEY,
-                        url=app.state.OPENAI_API_BASE_URL,
-                    )
-                    for text in texts
-                ]
+                )
+                for text in texts
+            ]
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            embeddings = [
+                generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=text,
+                    key=app.state.OPENAI_API_KEY,
+                    url=app.state.OPENAI_API_BASE_URL,
+                )
+                for text in texts
+            ]
 
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                embeddings=embeddings,
-                documents=texts,
-            ):
-                collection.add(*batch)
+        for batch in create_batches(
+            api=CHROMA_CLIENT,
+            ids=[str(uuid.uuid1()) for _ in texts],
+            metadatas=metadatas,
+            embeddings=embeddings,
+            documents=texts,
+        ):
+            collection.add(*batch)
 
         return True
     except Exception as e:
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index f4d1246c7..0ce299279 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -1,13 +1,12 @@
-import os
-import re
 import logging
-from typing import List
 
 import requests
+from typing import List
 
-from huggingface_hub import snapshot_download
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
+from apps.ollama.main import (
+    generate_ollama_embeddings,
+    GenerateEmbeddingsForm,
+)
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def query_doc(collection_name: str, query: str, k: int, embedding_function):
-    try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-            embedding_function=embedding_function,
-        )
-        result = collection.query(
-            query_texts=[query],
-            n_results=k,
-        )
-        return result
-    except Exception as e:
-        raise e
-
-
-def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
+def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
         log.info(f"query_embeddings_doc {query_embeddings}")
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-        )
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
+
         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
     return merged_query_results
 
 
-def query_collection(
-    collection_names: List[str], query: str, k: int, embedding_function
+def query_embeddings_collection(
+    collection_names: List[str], query: str, query_embeddings, k: int
 ):
-    results = []
-
-    for collection_name in collection_names:
-        try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=embedding_function,
-            )
-
-            result = collection.query(
-                query_texts=[query],
-                n_results=k,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(results, k)
-
-
-def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
-
     results = []
     log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
-            result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+            result = query_embeddings_doc(
+                collection_name=collection_name,
+                query=query,
+                query_embeddings=query_embeddings,
+                k=k,
             )
             results.append(result)
         except:
@@ -197,51 +156,38 @@ def rag_messages(
                 context = doc["content"]
             else:
                 if embedding_engine == "":
-                    if doc["type"] == "collection":
-                        context = query_collection(
-                            collection_names=doc["collection_names"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-                    else:
-                        context = query_doc(
-                            collection_name=doc["collection_name"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
+                    query_embeddings = embedding_function.encode(query).tolist()
+                elif embedding_engine == "ollama":
+                    query_embeddings = generate_ollama_embeddings(
+                        GenerateEmbeddingsForm(
+                            **{
+                                "model": embedding_model,
+                                "prompt": query,
+                            }
                         )
+                    )
+                elif embedding_engine == "openai":
+                    query_embeddings = generate_openai_embeddings(
+                        model=embedding_model,
+                        text=query,
+                        key=openai_key,
+                        url=openai_url,
+                    )
+
+                if doc["type"] == "collection":
+                    context = query_embeddings_collection(
+                        collection_names=doc["collection_names"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
                 else:
-                    if embedding_engine == "ollama":
-                        query_embeddings = generate_ollama_embeddings(
-                            GenerateEmbeddingsForm(
-                                **{
-                                    "model": embedding_model,
-                                    "prompt": query,
-                                }
-                            )
-                        )
-                    elif embedding_engine == "openai":
-                        query_embeddings = generate_openai_embeddings(
-                            model=embedding_model,
-                            text=query,
-                            key=openai_key,
-                            url=openai_url,
-                        )
-
-                    if doc["type"] == "collection":
-                        context = query_embeddings_collection(
-                            collection_names=doc["collection_names"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
-                    else:
-                        context = query_embeddings_doc(
-                            collection_name=doc["collection_name"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
+                    context = query_embeddings_doc(
+                        collection_name=doc["collection_name"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
 
         except Exception as e:
             log.exception(e)
@@ -283,46 +229,6 @@ def rag_messages(
     return messages
 
 
-def get_embedding_model_path(
-    embedding_model: str, update_embedding_model: bool = False
-):
-    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
-    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
-
-    local_files_only = not update_embedding_model
-
-    snapshot_kwargs = {
-        "cache_dir": cache_dir,
-        "local_files_only": local_files_only,
-    }
-
-    log.debug(f"embedding_model: {embedding_model}")
-    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
-
-    # Inspiration from upstream sentence_transformers
-    if (
-        os.path.exists(embedding_model)
-        or ("\\" in embedding_model or embedding_model.count("/") > 1)
-        and local_files_only
-    ):
-        # If fully qualified path exists, return input, else set repo_id
-        return embedding_model
-    elif "/" not in embedding_model:
-        # Set valid repo_id for model short-name
-        embedding_model = "sentence-transformers" + "/" + embedding_model
-
-    snapshot_kwargs["repo_id"] = embedding_model
-
-    # Attempt to query the huggingface_hub library to determine the local path and/or to update
-    try:
-        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
-        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
-        return embedding_model_repo_path
-    except Exception as e:
-        log.exception(f"Cannot determine embedding model snapshot path: {e}")
-        return embedding_model
-
-
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
diff --git a/backend/config.py b/backend/config.py
index 6ca2c67bf..17f8f91bf 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -411,18 +411,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
-RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+RAG_EMBEDDING_MODEL = os.environ.get(
+    "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
+)
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
-RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
-
 # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
diff --git a/backend/requirements.txt b/backend/requirements.txt
index c815d93da..d5c179d86 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -25,6 +25,7 @@ apscheduler
 google-generativeai
 
 langchain
+langchain-chroma
 langchain-community
 fake_useragent
 chromadb
@@ -43,6 +44,7 @@ opencv-python-headless
 rapidocr-onnxruntime
 
 fpdf2
+rank_bm25
 
 faster-whisper
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index a2bbec852..6abdda5af 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -180,7 +180,7 @@
 				}
 			}}
 		>
-
+
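Note: the sketch below is a quick sanity check for the new native loader and is not part of the patch itself. It assumes `sentence-transformers` is installed and that the default model can be downloaded from Hugging Face, and it mirrors the `.encode(...).tolist()` calls this patch introduces in `main.py` and `utils.py`:

    # Minimal sketch: confirm SentenceTransformer loads the new default model
    # and returns plain Python lists, as the patched code expects.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
        trust_remote_code=False,  # mirrors the RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE default
    )

    # One query string -> one vector; a list of texts -> one vector per text.
    query_embeddings = model.encode("What is in this document?").tolist()
    doc_embeddings = model.encode(["first chunk", "second chunk"]).tolist()

    print(len(query_embeddings))  # 384, the embedding dimension of all-MiniLM-L6-v2
    print(len(doc_embeddings))    # 2, one embedding per input text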