mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: add RAG_EMBEDDING_OPENAI_BATCH_SIZE to batch multiple embeddings
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from apps.ollama.main import (
|
||||
generate_ollama_embeddings,
|
||||
@@ -21,17 +21,7 @@ from langchain.retrievers import (
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
CHROMA_CLIENT,
|
||||
SEARXNG_QUERY_URL,
|
||||
GOOGLE_PSE_API_KEY,
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
SERPSTACK_API_KEY,
|
||||
SERPSTACK_HTTPS,
|
||||
SERPER_API_KEY,
|
||||
)
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@@ -209,6 +199,7 @@ def get_embedding_function(
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
@@ -232,7 +223,13 @@ def get_embedding_function(
|
||||
|
||||
def generate_multiple(query, f):
|
||||
if isinstance(query, list):
|
||||
return [f(q) for q in query]
|
||||
if embedding_engine == "openai":
|
||||
embeddings = []
|
||||
for i in range(0, len(query), batch_size):
|
||||
embeddings.extend(f(query[i : i + batch_size]))
|
||||
return embeddings
|
||||
else:
|
||||
return [f(q) for q in query]
|
||||
else:
|
||||
return f(query)
|
||||
|
||||
@@ -413,8 +410,22 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
||||
model: str,
|
||||
text: Union[str, list[str]],
|
||||
key: str,
|
||||
url: str = "https://api.openai.com/v1",
|
||||
):
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/embeddings",
|
||||
@@ -422,12 +433,12 @@ def generate_openai_embeddings(
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
},
|
||||
json={"input": text, "model": model},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return data["data"][0]["embedding"]
|
||||
return [elem["embedding"] for elem in data["data"]]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user