refac: embeddings function

This commit is contained in:
Timothy J. Baek 2024-10-09 11:41:35 -07:00
parent b38e2fab32
commit 451f1bae15

View File

@ -272,26 +272,26 @@ def get_embedding_function(
return lambda query: embedding_function.encode(query).tolist() return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]: elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama": if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings( func = lambda query: generate_embeddings(
model=embedding_model, model=embedding_model,
input=query, text=query,
) )
elif embedding_engine == "openai": elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings( func = lambda query: generate_embeddings(
model=embedding_model, model=embedding_model,
text=query, text=query,
key=openai_key, key=openai_key,
url=openai_url, url=openai_url,
) )
def generate_multiple(query, f): def generate_multiple(query, func):
if isinstance(query, list): if isinstance(query, list):
embeddings = [] embeddings = []
for i in range(0, len(query), embedding_batch_size): for i in range(0, len(query), embedding_batch_size):
embeddings.extend(f(query[i : i + embedding_batch_size])) embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings return embeddings
else: else:
return f(query) return func(query)
return lambda query: generate_multiple(query, func) return lambda query: generate_multiple(query, func)
@ -438,20 +438,6 @@ def get_model_path(model: str, update_model: bool = False):
return model return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings( def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]: ) -> Optional[list[list[float]]]:
@ -475,19 +461,31 @@ def generate_openai_batch_embeddings(
return None return None
def generate_ollama_embeddings( def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
model: str, input: list[str] if engine == "ollama":
) -> Optional[list[list[float]]]: if isinstance(text, list):
if isinstance(input, list): embeddings = generate_ollama_batch_embeddings(
embeddings = generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": text})
GenerateEmbedForm(**{"model": model, "input": input}) )
) else:
else: embeddings = generate_ollama_batch_embeddings(
embeddings = generate_ollama_batch_embeddings( GenerateEmbedForm(**{"model": model, "input": [text]})
GenerateEmbedForm(**{"model": model, "input": [input]}) )
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
) )
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
return embeddings["embeddings"] if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
import operator import operator