From abfcceecefa6251ab780558eb46931ffed9b9e78 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 10 Apr 2024 00:46:09 -0700 Subject: [PATCH] refac --- backend/apps/rag/main.py | 37 +++++++++---------- .../documents/Settings/General.svelte | 14 +++++-- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7bb9ee0ee..8846e4dce 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -142,43 +142,40 @@ class EmbeddingModelUpdateForm(BaseModel): async def update_embedding_model( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): - status = True - old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model log.debug(f"form_data.embedding_model: {form_data.embedding_model}") log.info( f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) + embedding_model_path = None + sentence_transformer_ef = None try: - app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path( - app.state.RAG_EMBEDDING_MODEL, True - ) - app.state.sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL_PATH, - device=DEVICE_TYPE, + embedding_model_path = get_embedding_model_path(form_data.embedding_model, True) + if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path: + sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=embedding_model_path, + device=DEVICE_TYPE, + ) ) - ) except Exception as e: log.exception(f"Problem updating embedding model: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, + detail=ERROR_MESSAGES.DEFAULT(e), ) - if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path: - status = False + if sentence_transformer_ef: + app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path + app.state.sentence_transformer_ef = sentence_transformer_ef - log.debug( - f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}" - ) - log.debug(f"old_model_path: {old_model_path}") - log.debug(f"status: {status}") + log.debug( + f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}" + ) return { - "status": status, + "status": sentence_transformer_ef != None, "embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, } diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index 17bd3ad04..a37ba1da9 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -35,6 +35,9 @@ k: 4 }; + let embeddingModelConfig = { + embedding_model: '' + }; let embeddingModel = ''; const scanHandler = async () => { @@ -61,7 +64,13 @@ console.log('Update embedding model attempt:', embeddingModel); updateEmbeddingModelLoading = true; - const res = await updateEmbeddingModel(localStorage.token, { embedding_model: embeddingModel }); + const res = await updateEmbeddingModel(localStorage.token, { + embedding_model: embeddingModel + }).catch((error) => { + toast.error(error); + embeddingModel = embeddingModelConfig.embedding_model; + return null; + }); updateEmbeddingModelLoading = false; if (res) { @@ -99,8 +108,7 @@ chunkOverlap = res.chunk.chunk_overlap; } - const embeddingModelConfig = await getEmbeddingModel(localStorage.token); - + embeddingModelConfig = await getEmbeddingModel(localStorage.token); embeddingModel = embeddingModelConfig.embedding_model; querySettings = await getQuerySettings(localStorage.token);