feat: web rag support

This commit is contained in:
Timothy J. Baek 2024-01-26 22:17:28 -08:00
parent 5e672d9f79
commit 28226a6f97
5 changed files with 131 additions and 33 deletions

View File

@ -37,7 +37,7 @@ from typing import Optional
import uuid import uuid
import time import time
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user from utils.utils import get_current_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -124,10 +124,15 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
try: try:
loader = WebBaseLoader(form_data.url) loader = WebBaseLoader(form_data.url)
data = loader.load() data = loader.load()
store_data_in_vector_db(data, form_data.collection_name)
collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name)
return { return {
"status": True, "status": True,
"collection_name": form_data.collection_name, "collection_name": collection_name,
"filename": form_data.url, "filename": form_data.url,
} }
except Exception as e: except Exception as e:

View File

@ -24,6 +24,16 @@ def calculate_sha256(file):
return sha256.hexdigest() return sha256.hexdigest()
def calculate_sha256_string(string):
# Create a new SHA-256 hash object
sha256_hash = hashlib.sha256()
# Update the hash object with the bytes of the input string
sha256_hash.update(string.encode("utf-8"))
# Get the hexadecimal representation of the hash
hashed_string = sha256_hash.hexdigest()
return hashed_string
def validate_email_format(email: str) -> bool: def validate_email_format(email: str) -> bool:
if not re.match(r"[^@]+@[^@]+\.[^@]+", email): if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
return False return False

View File

@ -6,7 +6,7 @@
import Prompts from './MessageInput/PromptCommands.svelte'; import Prompts from './MessageInput/PromptCommands.svelte';
import Suggestions from './MessageInput/Suggestions.svelte'; import Suggestions from './MessageInput/Suggestions.svelte';
import { uploadDocToVectorDB } from '$lib/apis/rag'; import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag';
import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte'; import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants'; import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
import Documents from './MessageInput/Documents.svelte'; import Documents from './MessageInput/Documents.svelte';
@ -137,6 +137,33 @@
} }
}; };
const uploadWeb = async (url) => {
console.log(url);
const doc = {
type: 'doc',
name: url,
collection_name: '',
upload_status: false,
error: ''
};
try {
files = [...files, doc];
const res = await uploadWebToVectorDB(localStorage.token, '', url);
if (res) {
doc.upload_status = true;
doc.collection_name = res.collection_name;
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(e);
}
};
onMount(() => { onMount(() => {
const dropZone = document.querySelector('body'); const dropZone = document.querySelector('body');
@ -258,6 +285,10 @@
<Documents <Documents
bind:this={documentsElement} bind:this={documentsElement}
bind:prompt bind:prompt
on:url={(e) => {
console.log(e);
uploadWeb(e.detail);
}}
on:select={(e) => { on:select={(e) => {
console.log(e); console.log(e);
files = [ files = [

View File

@ -2,7 +2,7 @@
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
import { documents } from '$lib/stores'; import { documents } from '$lib/stores';
import { removeFirstHashWord } from '$lib/utils'; import { removeFirstHashWord, isValidHttpUrl } from '$lib/utils';
import { tick } from 'svelte'; import { tick } from 'svelte';
export let prompt = ''; export let prompt = '';
@ -37,9 +37,20 @@
chatInputElement?.focus(); chatInputElement?.focus();
await tick(); await tick();
}; };
const confirmSelectWeb = async (url) => {
dispatch('url', url);
prompt = removeFirstHashWord(prompt);
const chatInputElement = document.getElementById('chat-textarea');
await tick();
chatInputElement?.focus();
await tick();
};
</script> </script>
{#if filteredDocs.length > 0} {#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<div class="md:px-2 mb-3 text-left w-full"> <div class="md:px-2 mb-3 text-left w-full">
<div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700"> <div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700">
<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center"> <div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center">
@ -55,6 +66,7 @@
: ''}" : ''}"
type="button" type="button"
on:click={() => { on:click={() => {
console.log(doc);
confirmSelect(doc); confirmSelect(doc);
}} }}
on:mousemove={() => { on:mousemove={() => {
@ -71,6 +83,25 @@
</div> </div>
</button> </button>
{/each} {/each}
{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<button
class="px-3 py-1.5 rounded-lg w-full text-left bg-gray-100 selected-command-option-button"
type="button"
on:click={() => {
const url = prompt.split(' ')?.at(0)?.substring(1);
if (isValidHttpUrl(url)) {
confirmSelectWeb(url);
}
}}
>
<div class=" font-medium text-black line-clamp-1">
{prompt.split(' ')?.at(0)?.substring(1)}
</div>
<div class=" text-xs text-gray-600 line-clamp-1">Web</div>
</button>
{/if}
</div> </div>
</div> </div>
</div> </div>

View File

@ -212,8 +212,12 @@ const convertOpenAIMessages = (convo) => {
const message = mapping[message_id]; const message = mapping[message_id];
currentId = message_id; currentId = message_id;
try { try {
if (messages.length == 0 && (message['message'] == null || if (
(message['message']['content']['parts']?.[0] == '' && message['message']['content']['text'] == null))) { messages.length == 0 &&
(message['message'] == null ||
(message['message']['content']['parts']?.[0] == '' &&
message['message']['content']['text'] == null))
) {
// Skip chat messages with no content // Skip chat messages with no content
continue; continue;
} else { } else {
@ -222,7 +226,10 @@ const convertOpenAIMessages = (convo) => {
parentId: lastId, parentId: lastId,
childrenIds: message['children'] || [], childrenIds: message['children'] || [],
role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user', role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user',
content: message['message']?.['content']?.['parts']?.[0] || message['message']?.['content']?.['text'] || '', content:
message['message']?.['content']?.['parts']?.[0] ||
message['message']?.['content']?.['text'] ||
'',
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
done: true, done: true,
context: null context: null
@ -231,7 +238,7 @@ const convertOpenAIMessages = (convo) => {
lastId = currentId; lastId = currentId;
} }
} catch (error) { } catch (error) {
console.log("Error with", message, "\nError:", error); console.log('Error with', message, '\nError:', error);
} }
} }
@ -256,31 +263,31 @@ const validateChat = (chat) => {
// Because ChatGPT sometimes has features we can't use like DALL-E or migh have corrupted messages, need to validate // Because ChatGPT sometimes has features we can't use like DALL-E or migh have corrupted messages, need to validate
const messages = chat.messages; const messages = chat.messages;
// Check if messages array is empty // Check if messages array is empty
if (messages.length === 0) { if (messages.length === 0) {
return false; return false;
} }
// Last message's children should be an empty array // Last message's children should be an empty array
const lastMessage = messages[messages.length - 1]; const lastMessage = messages[messages.length - 1];
if (lastMessage.childrenIds.length !== 0) { if (lastMessage.childrenIds.length !== 0) {
return false; return false;
} }
// First message's parent should be null // First message's parent should be null
const firstMessage = messages[0]; const firstMessage = messages[0];
if (firstMessage.parentId !== null) { if (firstMessage.parentId !== null) {
return false; return false;
} }
// Every message's content should be a string // Every message's content should be a string
for (let message of messages) { for (let message of messages) {
if (typeof message.content !== 'string') { if (typeof message.content !== 'string') {
return false; return false;
} }
} }
return true; return true;
}; };
export const convertOpenAIChats = (_chats) => { export const convertOpenAIChats = (_chats) => {
@ -298,8 +305,22 @@ export const convertOpenAIChats = (_chats) => {
chat: chat, chat: chat,
timestamp: convo['timestamp'] timestamp: convo['timestamp']
}); });
} else { failed ++} } else {
failed++;
}
} }
console.log(failed, "Conversations could not be imported"); console.log(failed, 'Conversations could not be imported');
return chats; return chats;
}; };
export const isValidHttpUrl = (string) => {
let url;
try {
url = new URL(string);
} catch (_) {
return false;
}
return url.protocol === 'http:' || url.protocol === 'https:';
};