mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 05:24:02 +00:00
feat: chromadb vector store api
This commit is contained in:
parent
b2c9f6dff8
commit
784b369cc9
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@ -5,4 +5,5 @@ uploads
|
||||
.ipynb_checkpoints
|
||||
*.db
|
||||
_test
|
||||
Pipfile
|
||||
Pipfile
|
||||
data/*
|
@ -1,9 +1,25 @@
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from apps.web.routers import auths, users, chats, modelfiles, utils
|
||||
from config import WEBUI_VERSION, WEBUI_AUTH
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from langchain.document_loaders import WebBaseLoader, TextLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.chains import RetrievalQA
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
import uuid
|
||||
|
||||
from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=EMBED_MODEL
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@ -18,6 +34,84 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
class StoreWebForm(BaseModel):
|
||||
url: str
|
||||
collection_name: Optional[str] = "test"
|
||||
|
||||
|
||||
def store_data_in_vector_db(data, collection_name):
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
|
||||
)
|
||||
docs = text_splitter.split_documents(data)
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name, embedding_function=EMBEDDING_FUNC
|
||||
)
|
||||
|
||||
collection.add(
|
||||
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
|
||||
@app.get("/query/{collection_name}")
|
||||
def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
)
|
||||
result = collection.query(query_texts=[query], n_results=k)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/web")
|
||||
def store_web(form_data: StoreWebForm):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
try:
|
||||
loader = WebBaseLoader(form_data.url)
|
||||
data = loader.load()
|
||||
store_data_in_vector_db(data, form_data.collection_name)
|
||||
return {"status": True}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/doc")
|
||||
def store_doc(file: UploadFile = File(...)):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
|
||||
try:
|
||||
print(file)
|
||||
file.filename = f"{uuid.uuid4()}-{file.filename}"
|
||||
contents = file.file.read()
|
||||
with open(f"./data/{file.filename}", "wb") as f:
|
||||
f.write(contents)
|
||||
f.close()
|
||||
|
||||
# loader = WebBaseLoader(form_data.url)
|
||||
# data = loader.load()
|
||||
# store_data_in_vector_db(data, form_data.collection_name)
|
||||
return {"status": True}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
def reset_vector_db():
|
||||
CHROMA_CLIENT.reset()
|
||||
return {"status": True}
|
||||
|
@ -1,11 +1,11 @@
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
import os
|
||||
import chromadb
|
||||
|
||||
from secrets import token_bytes
|
||||
from base64 import b64encode
|
||||
|
||||
import os
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
load_dotenv(find_dotenv("../.env"))
|
||||
|
||||
@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev")
|
||||
# OLLAMA_API_BASE_URL
|
||||
####################################
|
||||
|
||||
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
|
||||
"http://localhost:11434/api")
|
||||
OLLAMA_API_BASE_URL = os.environ.get(
|
||||
"OLLAMA_API_BASE_URL", "http://localhost:11434/api"
|
||||
)
|
||||
|
||||
if ENV == "prod":
|
||||
if OLLAMA_API_BASE_URL == "/ollama/api":
|
||||
@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
|
||||
|
||||
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
|
||||
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
||||
|
||||
####################################
|
||||
# RAG
|
||||
####################################
|
||||
|
||||
CHROMA_DATA_PATH = "./data/vector_db"
|
||||
EMBED_MODEL = "all-MiniLM-L6-v2"
|
||||
CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
|
||||
CHUNK_SIZE = 1500
|
||||
CHUNK_OVERLAP = 100
|
||||
|
@ -6,7 +6,6 @@ class MESSAGES(str, Enum):
|
||||
|
||||
|
||||
class ERROR_MESSAGES(str, Enum):
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
||||
@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum):
|
||||
UNAUTHORIZED = "401 Unauthorized"
|
||||
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
|
||||
ACTION_PROHIBITED = (
|
||||
"The requested action has been restricted as a security measure.")
|
||||
"The requested action has been restricted as a security measure."
|
||||
)
|
||||
|
||||
FILE_NOT_SENT = "FILE_NOT_SENT"
|
||||
NOT_FOUND = "We could not find what you're looking for :/"
|
||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
||||
|
Loading…
Reference in New Issue
Block a user