diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index 33d984655..eed74064f 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel):
model: str
- input: str
- truncate: Optional[bool]
+ input: list[str]
+ truncate: Optional[bool] = None
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None
@@ -560,48 +560,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
- if url_idx is None:
- model = form_data.model
-
- if ":" not in model:
- model = f"{model}:latest"
-
- if model in app.state.MODELS:
- url_idx = random.choice(app.state.MODELS[model]["urls"])
- else:
- raise HTTPException(
- status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
- )
-
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
- log.info(f"url: {url}")
-
- r = requests.request(
- method="POST",
- url=f"{url}/api/embed",
- headers={"Content-Type": "application/json"},
- data=form_data.model_dump_json(exclude_none=True).encode(),
- )
- try:
- r.raise_for_status()
-
- return r.json()
- except Exception as e:
- log.exception(e)
- error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = r.json()
- if "error" in res:
- error_detail = f"Ollama: {res['error']}"
- except Exception:
- error_detail = f"Ollama: {e}"
-
- raise HTTPException(
- status_code=r.status_code if r else 500,
- detail=error_detail,
- )
+ return generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
- if url_idx is None:
- model = form_data.model
-
- if ":" not in model:
- model = f"{model}:latest"
-
- if model in app.state.MODELS:
- url_idx = random.choice(app.state.MODELS[model]["urls"])
- else:
- raise HTTPException(
- status_code=400,
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
- )
-
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
- log.info(f"url: {url}")
-
- r = requests.request(
- method="POST",
- url=f"{url}/api/embeddings",
- headers={"Content-Type": "application/json"},
- data=form_data.model_dump_json(exclude_none=True).encode(),
- )
- try:
- r.raise_for_status()
-
- return r.json()
- except Exception as e:
- log.exception(e)
- error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = r.json()
- if "error" in res:
- error_detail = f"Ollama: {res['error']}"
- except Exception:
- error_detail = f"Ollama: {e}"
-
- raise HTTPException(
- status_code=r.status_code if r else 500,
- detail=error_detail,
- )
+ return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
- return data["embedding"]
+ return data
+ else:
+ raise Exception("Something went wrong :/")
+ except Exception as e:
+ log.exception(e)
+ error_detail = "Open WebUI: Server Connection Error"
+ if r is not None:
+ try:
+ res = r.json()
+ if "error" in res:
+ error_detail = f"Ollama: {res['error']}"
+ except Exception:
+ error_detail = f"Ollama: {e}"
+
+ raise HTTPException(
+ status_code=r.status_code if r else 500,
+ detail=error_detail,
+ )
+
+
+def generate_ollama_batch_embeddings(
+ form_data: GenerateEmbedForm,
+ url_idx: Optional[int] = None,
+):
+ log.info(f"generate_ollama_batch_embeddings {form_data}")
+
+ if url_idx is None:
+ model = form_data.model
+
+ if ":" not in model:
+ model = f"{model}:latest"
+
+ if model in app.state.MODELS:
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+ )
+
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+ log.info(f"url: {url}")
+
+ r = requests.request(
+ method="POST",
+ url=f"{url}/api/embed",
+ headers={"Content-Type": "application/json"},
+ data=form_data.model_dump_json(exclude_none=True).encode(),
+ )
+ try:
+ r.raise_for_status()
+
+ data = r.json()
+
+ log.info(f"generate_ollama_batch_embeddings {data}")
+
+ if "embeddings" in data:
+ return data
else:
raise Exception("Something went wrong :/")
except Exception as e:
@@ -788,8 +763,7 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
):
payload = {**form_data.model_dump(exclude_none=True)}
- log.debug(f"{payload = }")
-
+ log.debug(f"generate_chat_completion() - 1.payload = {payload}")
if "metadata" in payload:
del payload["metadata"]
@@ -824,7 +798,7 @@ async def generate_chat_completion(
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
- log.debug(payload)
+ log.debug(f"generate_chat_completion() - 2.payload = {payload}")
return await post_streaming_url(
f"{url}/api/chat",
diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py
index b3e2ed7c6..c80b2011d 100644
--- a/backend/open_webui/apps/retrieval/main.py
+++ b/backend/open_webui/apps/retrieval/main.py
@@ -63,7 +63,7 @@ from open_webui.config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
- RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ RAG_EMBEDDING_BATCH_SIZE,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
- "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
}
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
- "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
- batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
+ embedding_batch_size: Optional[int] = 1
@app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
- form_data.openai_config.batch_size
- if form_data.openai_config.batch_size
- else 1
- )
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -334,17 +330,17 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
return {
"status": True,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
- "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
- app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
embeddings = embedding_function(
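
The API contract change in this file, as a hedged sketch: `embedding_batch_size` moves to the top level of the update payload and applies to both engines, while `openai_config` keeps only credentials. The mount prefix and token handling below are assumptions, not code from this repo:

```python
# Hypothetical client call showing the reshaped /embedding/update payload;
# the /retrieval/api/v1 prefix and bearer token are assumptions.
import requests

payload = {
    "embedding_engine": "ollama",
    "embedding_model": "nomic-embed-text",
    "embedding_batch_size": 32,  # now top-level and engine-agnostic
}
r = requests.post(
    "http://localhost:8080/retrieval/api/v1/embedding/update",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
print(r.json()["embedding_batch_size"])  # echoed back on success
```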
diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 0fe206c96..b7d746bfa 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
- GenerateEmbeddingsForm,
- generate_ollama_embeddings,
+ GenerateEmbedForm,
+ generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
try:
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
- vectors=[query_embedding],
+ vectors=query_embedding,
limit=k,
)
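
The dropped brackets on `vectors=` track the batch helpers: the ollama wrapper introduced below returns a list of vectors even for a single query string, so wrapping it again here would double-nest. A toy shape check with invented values:

```python
# Toy shape check (values invented): one query already arrives as a
# one-vector batch, so query_doc passes it through unwrapped.
query_embedding = [[0.1, 0.2, 0.3]]  # list[list[float]]
assert isinstance(query_embedding[0], list)
# The old vectors=[query_embedding] would have nested this once more.
```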
@@ -265,19 +265,15 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
- batch_size,
+ embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
- GenerateEmbeddingsForm(
- **{
- "model": embedding_model,
- "prompt": query,
- }
- )
+ model=embedding_model,
+ input=query,
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@ def get_embedding_function(
def generate_multiple(query, f):
if isinstance(query, list):
- if embedding_engine == "openai":
- embeddings = []
- for i in range(0, len(query), batch_size):
- embeddings.extend(f(query[i : i + batch_size]))
- return embeddings
- else:
- return [f(q) for q in query]
+ embeddings = []
+ for i in range(0, len(query), embedding_batch_size):
+ embeddings.extend(f(query[i : i + embedding_batch_size]))
+ return embeddings
else:
return f(query)
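
The unified `generate_multiple` is small enough to check in isolation: every engine now gets the same slice-and-extend batching that was previously OpenAI-only. A self-contained sketch with a stand-in embedder:

```python
# Self-contained sketch of the batching loop; fake_embed stands in for the
# real ollama/openai callables and maps each text to a 1-d "vector".
def generate_multiple(query, f, embedding_batch_size=2):
    if isinstance(query, list):
        embeddings = []
        for i in range(0, len(query), embedding_batch_size):
            embeddings.extend(f(query[i : i + embedding_batch_size]))
        return embeddings
    return f(query)


fake_embed = lambda texts: [[float(len(t))] for t in texts]
print(generate_multiple(["a", "bb", "ccc"], fake_embed))
# [[1.0], [2.0], [3.0]]: two slices of 2 and 1, results flattened
```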
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
return None
+def generate_ollama_embeddings(
+    model: str, input: str | list[str]
+) -> Optional[list[list[float]]]:
+    # Accept a single string or a batch; normalize to a one-element batch
+    # so the /api/embed helper always receives a list.
+    if not isinstance(input, list):
+        input = [input]
+
+    embeddings = generate_ollama_batch_embeddings(
+        GenerateEmbedForm(**{"model": model, "input": input})
+    )
+
+ return embeddings["embeddings"]
+
+
import operator
from typing import Optional, Sequence
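
Usage of the new wrapper, as implied by the code above: a bare string is promoted to a one-element batch, and callers always get `list[list[float]]` back. This assumes a reachable Ollama instance, and the model name is illustrative:

```python
# Hedged usage sketch; "nomic-embed-text" is an example model name.
one = generate_ollama_embeddings(model="nomic-embed-text", input="hello")
many = generate_ollama_embeddings(
    model="nomic-embed-text", input=["hello", "world"]
)
assert len(one) == 1 and len(many) == 2  # always a list of vectors
```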
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index bfc9a4ded..22b3d385b 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -986,10 +986,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
- "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
- "rag.embedding_openai_batch_size",
- int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+ "RAG_EMBEDDING_BATCH_SIZE",
+ "rag.embedding_batch_size",
+ int(
+ os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
+ or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
+ ),
)
RAG_RERANKING_MODEL = PersistentConfig(
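
The config fallback keeps existing deployments working: the renamed variable wins when set, otherwise the legacy OpenAI-specific one is read, otherwise the default of 1 applies. One subtlety: an empty string in the new variable also falls through, since `or` treats `""` as falsy. A quick sketch:

```python
# Fallback order sketch: new name, then legacy name, then "1". Both an
# unset and an empty RAG_EMBEDDING_BATCH_SIZE fall through to the legacy
# variable, because os.environ.get(...) returns a falsy value either way.
import os

os.environ["RAG_EMBEDDING_OPENAI_BATCH_SIZE"] = "8"
size = int(
    os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
    or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
)
print(size)  # 8, until RAG_EMBEDDING_BATCH_SIZE is set explicitly
```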
diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts
index 9f49e9c0f..6c6b18b9f 100644
--- a/src/lib/apis/retrieval/index.ts
+++ b/src/lib/apis/retrieval/index.ts
@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
type OpenAIConfigForm = {
key: string;
url: string;
- batch_size: number;
};
type EmbeddingModelUpdateForm = {
openai_config?: OpenAIConfigForm;
embedding_engine: string;
embedding_model: string;
+ embedding_batch_size?: number;
};
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte
index d6f7dc987..d94146c7d 100644
--- a/src/lib/components/admin/Settings/Documents.svelte
+++ b/src/lib/components/admin/Settings/Documents.svelte
@@ -38,6 +38,7 @@
let embeddingEngine = '';
let embeddingModel = '';
+ let embeddingBatchSize = 1;
let rerankingModel = '';
let fileMaxSize = null;
@@ -53,7 +54,6 @@
let OpenAIKey = '';
let OpenAIUrl = '';
- let OpenAIBatchSize = 1;
let querySettings = {
template: '',
@@ -100,12 +100,16 @@
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_model: embeddingModel,
+ ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+ ? {
+ embedding_batch_size: embeddingBatchSize
+ }
+ : {}),
...(embeddingEngine === 'openai'
? {
openai_config: {
key: OpenAIKey,
- url: OpenAIUrl,
- batch_size: OpenAIBatchSize
+ url: OpenAIUrl
}
}
: {})
@@ -193,10 +197,10 @@
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
+ embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url;
- OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
}
};
@@ -309,6 +313,8 @@