refac: embeddings function

This commit is contained in:
Timothy J. Baek 2024-10-09 11:41:35 -07:00
parent b38e2fab32
commit 451f1bae15

View File

@ -272,26 +272,26 @@ def get_embedding_function(
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
func = lambda query: generate_embeddings(
model=embedding_model,
input=query,
text=query,
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
func = lambda query: generate_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
def generate_multiple(query, f):
def generate_multiple(query, func):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(f(query[i : i + embedding_batch_size]))
embeddings.extend(func(query[i : i + embedding_batch_size]))
return embeddings
else:
return f(query)
return func(query)
return lambda query: generate_multiple(query, func)
@ -438,20 +438,6 @@ def get_model_path(model: str, update_model: bool = False):
return model
def generate_openai_embeddings(
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
@ -475,19 +461,31 @@ def generate_openai_batch_embeddings(
return None
def generate_ollama_embeddings(
model: str, input: list[str]
) -> Optional[list[list[float]]]:
if isinstance(input, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": input})
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [input]})
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
)
else:
embeddings = generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
return embeddings["embeddings"]
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
import operator