From 784b369cc9279c8249da968d2f8dcefe7951bf9a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 6 Jan 2024 22:59:22 -0800 Subject: [PATCH] feat: chromadb vector store api --- backend/.gitignore | 3 +- backend/apps/rag/main.py | 100 +++++++++++++++++++++++++++++++++++++-- backend/config.py | 21 ++++++-- backend/constants.py | 6 ++- 4 files changed, 119 insertions(+), 11 deletions(-) diff --git a/backend/.gitignore b/backend/.gitignore index da641cf7d..62a3a06a0 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -5,4 +5,5 @@ uploads .ipynb_checkpoints *.db _test -Pipfile \ No newline at end of file +Pipfile +data/* \ No newline at end of file diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6870792de..7dae9bc24 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,9 +1,25 @@ -from fastapi import FastAPI, Request, Depends, HTTPException +from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, chats, modelfiles, utils -from config import WEBUI_VERSION, WEBUI_AUTH +from chromadb.utils import embedding_functions +from langchain.document_loaders import WebBaseLoader, TextLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import Chroma +from langchain.chains import RetrievalQA + + +from pydantic import BaseModel +from typing import Optional + +import uuid + +from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from constants import ERROR_MESSAGES + +EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=EMBED_MODEL +) app = FastAPI() @@ -18,6 +34,84 @@ app.add_middleware( ) +class StoreWebForm(BaseModel): + url: str + collection_name: Optional[str] = "test" + + +def store_data_in_vector_db(data, collection_name): + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + ) + docs = text_splitter.split_documents(data) + + texts = [doc.page_content for doc in docs] + metadatas = [doc.metadata for doc in docs] + + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=EMBEDDING_FUNC + ) + + collection.add( + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) + + @app.get("/") async def get_status(): return {"status": True} + + +@app.get("/query/{collection_name}") +def query_collection(collection_name: str, query: str, k: Optional[int] = 4): + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query(query_texts=[query], n_results=k) + + return result + + +@app.post("/web") +def store_web(form_data: StoreWebForm): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + try: + loader = WebBaseLoader(form_data.url) + data = loader.load() + store_data_in_vector_db(data, form_data.collection_name) + return {"status": True} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@app.post("/doc") +def store_doc(file: UploadFile = File(...)): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + + try: + print(file) + file.filename = f"{uuid.uuid4()}-{file.filename}" + contents = file.file.read() + with open(f"./data/{file.filename}", "wb") as f: + f.write(contents) + f.close() + + # loader = WebBaseLoader(form_data.url) + # data = loader.load() + # store_data_in_vector_db(data, form_data.collection_name) + return {"status": True} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +def reset_vector_db(): + CHROMA_CLIENT.reset() + return {"status": True} diff --git a/backend/config.py b/backend/config.py index 4c518d139..df57c8292 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,11 +1,11 @@ from dotenv import load_dotenv, find_dotenv - -from constants import ERROR_MESSAGES +import os +import chromadb from secrets import token_bytes from base64 import b64encode -import os +from constants import ERROR_MESSAGES load_dotenv(find_dotenv("../.env")) @@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev") # OLLAMA_API_BASE_URL #################################### -OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", - "http://localhost:11434/api") +OLLAMA_API_BASE_URL = os.environ.get( + "OLLAMA_API_BASE_URL", "http://localhost:11434/api" +) if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": @@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +#################################### +# RAG +#################################### + +CHROMA_DATA_PATH = "./data/vector_db" +EMBED_MODEL = "all-MiniLM-L6-v2" +CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH) +CHUNK_SIZE = 1500 +CHUNK_OVERLAP = 100 diff --git a/backend/constants.py b/backend/constants.py index c3fd0dc5f..9893744c3 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -6,7 +6,6 @@ class MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum): - def __str__(self) -> str: return super().__str__() @@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum): UNAUTHORIZED = "401 Unauthorized" ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACTION_PROHIBITED = ( - "The requested action has been restricted as a security measure.") + "The requested action has been restricted as a security measure." + ) + + FILE_NOT_SENT = "FILE_NOT_SENT" NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."