From b48e73fa43bc574b634e7632f6133093eb221429 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 14 Apr 2024 19:15:39 -0400 Subject: [PATCH] feat: openai embeddings support --- backend/apps/rag/main.py | 158 +++++++++++++++++++++++++------------- backend/apps/rag/utils.py | 23 ++++++ 2 files changed, 127 insertions(+), 54 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 976c7735b..118600329 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -53,6 +53,7 @@ from apps.rag.utils import ( query_collection, query_embeddings_collection, get_embedding_model_path, + generate_openai_embeddings, ) from utils.misc import ( @@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com" +app.state.RAG_OPENAI_API_KEY = "" app.state.PDF_EXTRACT_IMAGES = False @@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)): "status": True, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "openai_config": { + "url": app.state.RAG_OPENAI_API_BASE_URL, + "key": app.state.RAG_OPENAI_API_KEY, + }, } +class OpenAIConfigForm(BaseModel): + url: str + key: str + + class EmbeddingModelUpdateForm(BaseModel): + openai_config: Optional[OpenAIConfigForm] = None embedding_engine: str embedding_model: str @@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel): async def update_embedding_config( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): - log.info( f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) - try: app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - if app.state.RAG_EMBEDDING_ENGINE == "ollama": + if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.sentence_transformer_ef = None + + if form_data.openai_config != None: + app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url + app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key else: sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( @@ -183,6 +198,10 @@ async def update_embedding_config( "status": True, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "openai_config": { + "url": app.state.RAG_OPENAI_API_BASE_URL, + "key": app.state.RAG_OPENAI_API_KEY, + }, } except Exception as e: @@ -275,28 +294,37 @@ def query_doc_handler( ): try: - if app.state.RAG_EMBEDDING_ENGINE == "ollama": - query_embeddings = generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": app.state.RAG_EMBEDDING_MODEL, - "prompt": form_data.query, - } - ) - ) - - return query_embeddings_doc( - collection_name=form_data.collection_name, - query_embeddings=query_embeddings, - k=form_data.k if form_data.k else app.state.TOP_K, - ) - else: + if app.state.RAG_EMBEDDING_ENGINE == "": return query_doc( collection_name=form_data.collection_name, query=form_data.query, k=form_data.k if form_data.k else app.state.TOP_K, embedding_function=app.state.sentence_transformer_ef, ) + else: + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } + ) + ) + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + query_embeddings = generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=form_data.query, + key=app.state.RAG_OPENAI_API_KEY, + url=app.state.RAG_OPENAI_API_BASE_URL, + ) + + return query_embeddings_doc( + collection_name=form_data.collection_name, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + except Exception as e: log.exception(e) raise HTTPException( @@ -317,28 +345,38 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - if app.state.RAG_EMBEDDING_ENGINE == "ollama": - query_embeddings = generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": app.state.RAG_EMBEDDING_MODEL, - "prompt": form_data.query, - } - ) - ) - - return query_embeddings_collection( - collection_names=form_data.collection_names, - query_embeddings=query_embeddings, - k=form_data.k if form_data.k else app.state.TOP_K, - ) - else: + if app.state.RAG_EMBEDDING_ENGINE == "": return query_collection( collection_names=form_data.collection_names, query=form_data.query, k=form_data.k if form_data.k else app.state.TOP_K, embedding_function=app.state.sentence_transformer_ef, ) + else: + + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } + ) + ) + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + query_embeddings = generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=form_data.query, + key=app.state.RAG_OPENAI_API_KEY, + url=app.state.RAG_OPENAI_API_BASE_URL, + ) + + return query_embeddings_collection( + collection_names=form_data.collection_names, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + except Exception as e: log.exception(e) raise HTTPException( @@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b log.info(f"deleting existing collection {collection_name}") CHROMA_CLIENT.delete_collection(name=collection_name) - if app.state.RAG_EMBEDDING_ENGINE == "ollama": - collection = CHROMA_CLIENT.create_collection(name=collection_name) - - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid1()) for _ in texts], - metadatas=metadatas, - embeddings=[ - generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{"model": RAG_EMBEDDING_MODEL, "prompt": text} - ) - ) - for text in texts - ], - ): - collection.add(*batch) - else: + if app.state.RAG_EMBEDDING_ENGINE == "": collection = CHROMA_CLIENT.create_collection( name=collection_name, @@ -446,7 +467,36 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ): collection.add(*batch) - return True + else: + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + embeddings = [ + generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text} + ) + ) + for text in texts + ] + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + embeddings = [ + generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=text, + key=app.state.RAG_OPENAI_API_KEY, + url=app.state.RAG_OPENAI_API_BASE_URL, + ) + for text in texts + ] + + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + embeddings=embeddings, + ): + collection.add(*batch) + + return True except Exception as e: log.exception(e) if e.__class__.__name__ == "UniqueConstraintError": diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 17d8e4a9a..a0956e2fc 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -269,3 +269,26 @@ def get_embedding_model_path( except Exception as e: log.exception(f"Cannot determine embedding model snapshot path: {e}") return embedding_model + + +def generate_openai_embeddings( + model: str, text: str, key: str, url: str = "https://api.openai.com" +): + try: + r = requests.post( + f"{url}/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": text, "model": model}, + ) + r.raise_for_status() + data = r.json() + if "data" in data: + return data["data"][0]["embedding"] + else: + raise "Something went wrong :/" + except Exception as e: + print(e) + return None