From 451f1bae1551789c273e02b3eef915fef1a12549 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Wed, 9 Oct 2024 11:41:35 -0700
Subject: [PATCH] refac: embeddings function

Merge generate_ollama_embeddings and generate_openai_embeddings into a
single generate_embeddings(engine, model, text, **kwargs) helper, and
update the call sites in get_embedding_function to pass the required
`engine` argument (the original revision omitted it, which would raise
TypeError as soon as an ollama/openai embedding query ran).
---
 backend/open_webui/apps/retrieval/utils.py | 62 ++++++++++++-----------
 1 file changed, 31 insertions(+), 31 deletions(-)

diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 992d46788..8364728d5 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -272,25 +272,27 @@ def get_embedding_function(
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
-            func = lambda query: generate_ollama_embeddings(
+            func = lambda query: generate_embeddings(
+                engine="ollama",
                 model=embedding_model,
-                input=query,
+                text=query,
             )
         elif embedding_engine == "openai":
-            func = lambda query: generate_openai_embeddings(
+            func = lambda query: generate_embeddings(
+                engine="openai",
                 model=embedding_model,
                 text=query,
                 key=openai_key,
                 url=openai_url,
             )
 
-        def generate_multiple(query, f):
+        def generate_multiple(query, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(f(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size]))
                 return embeddings
             else:
-                return f(query)
+                return func(query)
 
         return lambda query: generate_multiple(query, func)
@@ -438,20 +438,6 @@ def get_model_path(model: str, update_model: bool = False):
     return model
 
 
-def generate_openai_embeddings(
-    model: str,
-    text: Union[str, list[str]],
-    key: str,
-    url: str = "https://api.openai.com/v1",
-):
-    if isinstance(text, list):
-        embeddings = generate_openai_batch_embeddings(model, text, key, url)
-    else:
-        embeddings = generate_openai_batch_embeddings(model, [text], key, url)
-
-    return embeddings[0] if isinstance(text, str) else embeddings
-
-
 def generate_openai_batch_embeddings(
     model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
 ) -> Optional[list[list[float]]]:
@@ -475,16 +461,28 @@ def generate_openai_batch_embeddings(
     return None
 
 
-def generate_ollama_embeddings(
-    model: str, input: list[str]
-) -> Optional[list[list[float]]]:
-    if isinstance(input, list):
-        embeddings = generate_ollama_batch_embeddings(
-            GenerateEmbedForm(**{"model": model, "input": input})
-        )
-    else:
-        embeddings = generate_ollama_batch_embeddings(
-            GenerateEmbedForm(**{"model": model, "input": [input]})
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+    if engine == "ollama":
+        if isinstance(text, list):
+            embeddings = generate_ollama_batch_embeddings(
+                GenerateEmbedForm(**{"model": model, "input": text})
+            )
+        else:
+            embeddings = generate_ollama_batch_embeddings(
+                GenerateEmbedForm(**{"model": model, "input": [text]})
+            )
+        return (
+            embeddings["embeddings"][0]
+            if isinstance(text, str)
+            else embeddings["embeddings"]
         )
+    elif engine == "openai":
+        key = kwargs.get("key", "")
+        url = kwargs.get("url", "https://api.openai.com/v1")
 
-    return embeddings["embeddings"]
+        if isinstance(text, list):
+            embeddings = generate_openai_batch_embeddings(model, text, key, url)
+        else:
+            embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+
+        return embeddings[0] if isinstance(text, str) else embeddings