feat: full integration

This commit is contained in:
Timothy J. Baek 2024-01-07 01:40:36 -08:00
parent 28c1192ac0
commit 9634e2da3e
6 changed files with 116 additions and 25 deletions

View File

@ -9,6 +9,7 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from chromadb.utils import embedding_functions
@ -23,7 +24,7 @@ from typing import Optional
import uuid
from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
@ -51,7 +52,7 @@ class StoreWebForm(CollectionNameForm):
url: str
def store_data_in_vector_db(data, collection_name):
def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
)
@ -60,13 +61,22 @@ def store_data_in_vector_db(data, collection_name):
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=EMBEDDING_FUNC
)
try:
collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=EMBEDDING_FUNC
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
print(e.__class__.__name__)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
@app.get("/")
@ -116,7 +126,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
try:
filename = file.filename
file_path = f"./data/{filename}"
file_path = f"{UPLOAD_DIR}/{filename}"
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
@ -128,8 +138,15 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
loader = TextLoader(file_path)
data = loader.load()
store_data_in_vector_db(data, collection_name)
return {"status": True, "collection_name": collection_name}
result = store_data_in_vector_db(data, collection_name)
if result:
return {"status": True, "collection_name": collection_name}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
except Exception as e:
print(e)
raise HTTPException(
@ -138,6 +155,27 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
)
@app.get("/reset/db")
def reset_vector_db():
CHROMA_CLIENT.reset()
@app.get("/reset")
def reset():
folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e))
try:
CHROMA_CLIENT.reset()
except Exception as e:
print(e)
return {"status": True}

View File

@ -1,14 +1,31 @@
from dotenv import load_dotenv, find_dotenv
import os
import chromadb
from chromadb import Settings
from secrets import token_bytes
from base64 import b64encode
from constants import ERROR_MESSAGES
from pathlib import Path
load_dotenv(find_dotenv("../.env"))
####################################
# File Upload
####################################
UPLOAD_DIR = "./data/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
####################################
# ENV (dev,test,prod)
####################################
@ -64,6 +81,8 @@ if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
CHROMA_DATA_PATH = "./data/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2"
CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True)
)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

View File

@ -124,16 +124,16 @@
reader.readAsDataURL(file);
} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
console.log(file);
const hash = await calculateSHA256(file);
// const res = uploadDocToVectorDB(localStorage.token, hash,file);
const hash = (await calculateSHA256(file)).substring(0, 63);
const res = await uploadDocToVectorDB(localStorage.token, hash, file);
if (true) {
if (res) {
files = [
...files,
{
type: 'doc',
name: file.name,
collection_name: hash
collection_name: res.collection_name
}
];
}
@ -243,16 +243,16 @@
reader.readAsDataURL(file);
} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
console.log(file);
const hash = await calculateSHA256(file);
// const res = uploadDocToVectorDB(localStorage.token,hash,file);
const hash = (await calculateSHA256(file)).substring(0, 63);
const res = await uploadDocToVectorDB(localStorage.token, hash, file);
if (true) {
if (res) {
files = [
...files,
{
type: 'doc',
name: file.name,
collection_name: hash
collection_name: res.collection_name
}
];
filesInputElement.value = '';
@ -280,7 +280,7 @@
<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
{:else if file.type === 'doc'}
<div
class="h-16 w-[15rem] flex items-center space-x-3 px-2 bg-gray-600 rounded-xl"
class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 bg-gray-600 rounded-xl"
>
<div class="p-2.5 bg-red-400 rounded-lg">
<svg

View File

@ -53,11 +53,41 @@
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line"
>
{#if message.files}
<div class="my-3 w-full flex overflow-x-auto space-x-2">
<div class="my-2.5 w-full flex overflow-x-auto space-x-2 flex-wrap">
{#each message.files as file}
<div>
{#if file.type === 'image'}
<img src={file.url} alt="input" class=" max-h-96 rounded-lg" draggable="false" />
{:else if file.type === 'doc'}
<div
class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 bg-gray-600 rounded-xl"
>
<div class="p-2.5 bg-red-400 rounded-lg">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="w-6 h-6"
>
<path
fill-rule="evenodd"
d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
clip-rule="evenodd"
/>
<path
d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
/>
</svg>
</div>
<div class="flex flex-col justify-center -space-y-0.5">
<div class=" text-gray-100 text-sm line-clamp-1">
{file.name}
</div>
<div class=" text-gray-500 text-sm">Document</div>
</div>
</div>
{/if}
</div>
{/each}

View File

@ -129,7 +129,6 @@ export const findWordIndices = (text) => {
};
export const calculateSHA256 = async (file) => {
console.log(file);
// Create a FileReader to read the file asynchronously
const reader = new FileReader();
@ -156,7 +155,7 @@ export const calculateSHA256 = async (file) => {
const hashArray = Array.from(new Uint8Array(hashBuffer));
const hashHex = hashArray.map((byte) => byte.toString(16).padStart(2, '0')).join('');
return `sha256:${hashHex}`;
return `${hashHex}`;
} catch (error) {
console.error('Error calculating SHA-256 hash:', error);
throw error;

View File

@ -186,8 +186,11 @@
const _chatId = JSON.parse(JSON.stringify($chatId));
// TODO: update below to include all ancestral files
const docs = history.messages[parentId].files.filter((item) => item.type === 'file');
console.log(history.messages[parentId]);
const docs = history.messages[parentId]?.files?.filter((item) => item.type === 'doc') ?? [];
console.log(docs);
if (docs.length > 0) {
const query = history.messages[parentId].content;
@ -207,6 +210,8 @@
return `${a}${context.documents.join(' ')}\n`;
}, '');
console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].contexts = relevantContexts;
await tick();