Merge branch 'dev' into feat/backend-web-search

This commit is contained in:
Jun Siang Cheah
2024-05-14 14:03:23 +08:00
48 changed files with 2650 additions and 1030 deletions

View File

@@ -95,6 +95,7 @@ from config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
AppConfig,
)
from constants import ERROR_MESSAGES
@@ -104,30 +105,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config = AppConfig()
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
@@ -135,7 +138,7 @@ def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@@ -160,22 +163,22 @@ def update_reranking_model(
update_embedding_model(
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
app.state.RAG_RERANKING_MODEL,
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
origins = ["*"]
@@ -206,12 +209,12 @@ class SearchForm(CollectionNameForm):
async def get_status():
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
"template": app.state.config.RAG_TEMPLATE,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
@@ -219,18 +222,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
return {
"status": True,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
class OpenAIConfigForm(BaseModel):
@@ -249,34 +255,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
except Exception as e:
@@ -296,16 +302,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
return {
"status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
@@ -319,14 +325,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -351,50 +357,52 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = (
app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
if form_data.pdf_extract_images is not None
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
app.state.CHUNK_OVERLAP = (
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.YOUTUBE_LOADER_LANGUAGE = (
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube != None
else app.state.YOUTUBE_LOADER_LANGUAGE
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube != None
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -404,7 +412,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"template": app.state.config.RAG_TEMPLATE,
}
@@ -412,10 +420,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -430,16 +438,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template else RAG_TEMPLATE,
)
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False,
)
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -457,21 +469,23 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
log.exception(e)
@@ -495,21 +509,23 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
@@ -526,7 +542,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE,
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
)
data = loader.load()
@@ -554,7 +570,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
form_data.url,
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
)
data = loader.load()
@@ -659,8 +676,8 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
@@ -677,8 +694,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
@@ -701,11 +718,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -779,7 +796,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":