From 1846c1e80dc597d83ad70759742abce67884c0e0 Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Sat, 17 Feb 2024 19:38:29 +0100 Subject: [PATCH] choose embedding model when using docker --- Dockerfile | 12 ++++++++-- backend/apps/rag/main.py | 51 ++++++++++++++++++++++++++-------------- backend/config.py | 3 ++- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/Dockerfile b/Dockerfile index 520c2964d..722303483 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY "" ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true -#Whisper TTS Settings +# whisper TTS Settings ENV WHISPER_MODEL="base" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" +# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers +# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard +# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" +# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. +ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" + WORKDIR /app/backend # install python dependencies @@ -48,7 +54,9 @@ RUN apt-get update \ && apt-get install -y pandoc netcat-openbsd \ && rm -rf /var/lib/apt/lists/* -# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" +# preload embedding model +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" +# preload tts model RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 07a30adee..defe10f95 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,6 +1,5 @@ from fastapi import ( FastAPI, - Request, Depends, HTTPException, status, @@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware import os, shutil from typing import List -# from chromadb.utils import embedding_functions +from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( WebBaseLoader, @@ -28,24 +27,19 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores import Chroma -from langchain.chains import RetrievalQA from pydantic import BaseModel from typing import Optional import uuid -import time from utils.misc import calculate_sha256, calculate_sha256_string from utils.utils import get_current_user, get_admin_user -from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES -# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( -# model_name=EMBED_MODEL -# ) +sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) app = FastAPI() @@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - collection = CHROMA_CLIENT.create_collection(name=collection_name) + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) + + else: + # for local development use the default model + collection = CHROMA_CLIENT.create_collection(name=collection_name) collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) return True except Exception as e: print(e) @@ -109,9 +109,17 @@ def query_doc( user=Depends(get_current_user), ): try: - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - ) + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + embedding_function=sentence_transformer_ef + ) + else: + # for local development use the default model + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: @@ -182,9 +190,18 @@ def query_collection( for collection_name in form_data.collection_names: try: - collection = CHROMA_CLIENT.get_collection( - name=collection_name, + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + embedding_function=sentence_transformer_ef + ) + else: + # for local development use the default model + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, ) + result = collection.query( query_texts=[form_data.query], n_results=form_data.k ) diff --git a/backend/config.py b/backend/config.py index d7c89b3ba..023954a4d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": #################################### CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" -EMBED_MODEL = "all-MiniLM-L6-v2" +# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) +SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False),