This commit is contained in:
Timothy Jaeryang Baek 2025-05-30 01:19:56 +04:00
parent 036ce12dd9
commit 9306ae5972

View File

@ -818,22 +818,10 @@ def generate_embeddings(
text = f"{prefix}{text}" text = f"{prefix}{text}"
if engine == "ollama": if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings( embeddings = generate_ollama_batch_embeddings(
**{ **{
"model": model, "model": model,
"texts": text, "texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
else:
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": [text],
"url": url, "url": url,
"key": key, "key": key,
"prefix": prefix, "prefix": prefix,
@ -842,35 +830,17 @@ def generate_embeddings(
) )
return embeddings[0] if isinstance(text, str) else embeddings return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai": elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings( embeddings = generate_openai_batch_embeddings(
model, text, url, key, prefix, user model, text if isinstance(text, list) else [text], url, key, prefix, user
)
else:
embeddings = generate_openai_batch_embeddings(
model, [text], url, key, prefix, user
) )
return embeddings[0] if isinstance(text, str) else embeddings return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai": elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "") azure_api_version = kwargs.get("azure_api_version", "")
if isinstance(text, list):
embeddings = generate_azure_openai_batch_embeddings( embeddings = generate_azure_openai_batch_embeddings(
model, model,
text, text if isinstance(text, list) else [text],
url, url,
key, key,
model,
azure_api_version,
prefix,
user,
)
else:
embeddings = generate_azure_openai_batch_embeddings(
model,
[text],
url,
key,
model,
azure_api_version, azure_api_version,
prefix, prefix,
user, user,