feat: switch to config proxy, remove config_get/set

This commit is contained in:
Jun Siang Cheah
2024-05-10 15:03:24 +08:00
parent f712c90019
commit 298e6848b3
11 changed files with 340 additions and 379 deletions

View File

@@ -93,8 +93,7 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
config_set,
config_get,
AppConfig,
)
from constants import ERROR_MESSAGES
@@ -104,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config = AppConfig()
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
@@ -135,7 +136,7 @@ def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "":
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@@ -160,22 +161,22 @@ def update_reranking_model(
update_embedding_model(
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
config_get(app.state.RAG_RERANKING_MODEL),
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
origins = ["*"]
@@ -202,12 +203,12 @@ class UrlForm(CollectionNameForm):
async def get_status():
return {
"status": True,
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
"template": config_get(app.state.RAG_TEMPLATE),
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
"template": app.state.config.RAG_TEMPLATE,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
@@ -215,11 +216,11 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": config_get(app.state.OPENAI_API_BASE_URL),
"key": config_get(app.state.OPENAI_API_KEY),
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
@@ -228,7 +229,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
async def get_reraanking_config(user=Depends(get_admin_user)):
return {
"status": True,
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
@@ -248,34 +249,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine)
config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model)
app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]:
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url)
config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key)
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL, True)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
return {
"status": True,
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": config_get(app.state.OPENAI_API_BASE_URL),
"key": config_get(app.state.OPENAI_API_KEY),
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
except Exception as e:
@@ -295,16 +296,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model)
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True)
update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
return {
"status": True,
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
@@ -318,16 +319,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -352,49 +351,34 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
config_set(
app.state.PDF_EXTRACT_IMAGES,
(
form_data.pdf_extract_images
if form_data.pdf_extract_images is not None
else config_get(app.state.PDF_EXTRACT_IMAGES)
),
app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images is not None
else app.state.config.PDF_EXTRACT_IMAGES
)
config_set(
app.state.CHUNK_SIZE,
(
form_data.chunk.chunk_size
if form_data.chunk is not None
else config_get(app.state.CHUNK_SIZE)
),
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
config_set(
app.state.CHUNK_OVERLAP,
(
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else config_get(app.state.CHUNK_OVERLAP)
),
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
config_set(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
(
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION)
),
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
config_set(
app.state.YOUTUBE_LOADER_LANGUAGE,
(
form_data.youtube.language
if form_data.youtube is not None
else config_get(app.state.YOUTUBE_LOADER_LANGUAGE)
),
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
@@ -405,16 +389,14 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
return {
"status": True,
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -424,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": config_get(app.state.RAG_TEMPLATE),
"template": app.state.config.RAG_TEMPLATE,
}
@@ -432,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
"template": config_get(app.state.RAG_TEMPLATE),
"k": config_get(app.state.TOP_K),
"r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -450,22 +432,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
config_set(
app.state.RAG_TEMPLATE,
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template else RAG_TEMPLATE
)
config_set(app.state.TOP_K, form_data.k if form_data.k else 4)
config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0)
config_set(
app.state.ENABLE_RAG_HYBRID_SEARCH,
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False
)
return {
"status": True,
"template": config_get(app.state.RAG_TEMPLATE),
"k": config_get(app.state.TOP_K),
"r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
@@ -483,17 +463,15 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
@@ -501,7 +479,7 @@ def query_doc_handler(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
log.exception(e)
@@ -525,17 +503,15 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
@@ -543,7 +519,7 @@ def query_collection_handler(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
@@ -560,8 +536,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION),
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
)
data = loader.load()
@@ -589,7 +565,7 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
try:
loader = get_web_loader(
form_data.url,
verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION),
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
)
data = loader.load()
@@ -645,8 +621,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
@@ -663,8 +639,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
@@ -687,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -766,7 +742,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES)
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
)
elif file_ext == "csv":
loader = CSVLoader(file_path)