From 582d11f191e464af0b45753b0d2b900751fef238 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 10 Apr 2024 00:59:05 -0700 Subject: [PATCH] refac: RAG_EMBEDDING_MODEL_PATH removed --- backend/apps/rag/main.py | 47 +++++++++++++++------------------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 8739fa6c4..219a26f7b 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -80,16 +80,15 @@ app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL -app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path( - app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE -) app.state.TOP_K = 4 app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL_PATH, + model_name=get_embedding_model_path( + app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE + ), device=DEVICE_TYPE, ) ) @@ -130,7 +129,6 @@ async def get_embedding_model(user=Depends(get_admin_user)): return { "status": True, "embedding_model": app.state.RAG_EMBEDDING_MODEL, - "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, } @@ -143,22 +141,26 @@ async def update_embedding_model( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): - log.debug(f"form_data.embedding_model: {form_data.embedding_model}") log.info( f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) - embedding_model_path = None - sentence_transformer_ef = None try: - embedding_model_path = get_embedding_model_path(form_data.embedding_model, True) - if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path: - sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=embedding_model_path, - device=DEVICE_TYPE, - ) + sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=get_embedding_model_path(form_data.embedding_model, True), + device=DEVICE_TYPE, ) + ) + + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = sentence_transformer_ef + + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + except Exception as e: log.exception(f"Problem updating embedding model: {e}") raise HTTPException( @@ -166,21 +168,6 @@ async def update_embedding_model( detail=ERROR_MESSAGES.DEFAULT(e), ) - if sentence_transformer_ef: - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model - app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path - app.state.sentence_transformer_ef = sentence_transformer_ef - - log.debug( - f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}" - ) - - return { - "status": sentence_transformer_ef != None, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, - "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, - } - @app.get("/config") async def get_rag_config(user=Depends(get_admin_user)):