Mirror of https://github.com/open-webui/open-webui (synced 2024-11-06 16:59:42 +00:00)
Merge pull request #4715 from Peter-De-Ath/ollama-batch-embeddings

feat: support ollama batch processing embeddings

commit d05c2f56ea
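For context on what the diff below targets (not stated in the PR itself, so treat the endpoint details as drawn from Ollama's public API docs): Ollama's legacy /api/embeddings endpoint embeds one prompt per request, while the newer /api/embed endpoint accepts a list of inputs and returns one vector per item. A minimal sketch of the two request shapes, assuming a local Ollama on the default port with an all-minilm model already pulled:

    import requests

    OLLAMA_URL = "http://localhost:11434"  # assumed default Ollama address

    # Legacy single-prompt endpoint: returns {"embedding": [...]}
    r = requests.post(
        f"{OLLAMA_URL}/api/embeddings",
        json={"model": "all-minilm", "prompt": "hello world"},
    )
    print(len(r.json()["embedding"]))  # dimensionality of the single vector

    # Batch endpoint: returns {"embeddings": [[...], [...]]}
    r = requests.post(
        f"{OLLAMA_URL}/api/embed",
        json={"model": "all-minilm", "input": ["hello world", "goodbye world"]},
    )
    print(len(r.json()["embeddings"]))  # 2: one vector per input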
@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
 
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: str
-    truncate: Optional[bool]
+    input: list[str]
+    truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
 
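The reshaped form mirrors the /api/embed request body, which the endpoint forwards verbatim. A small sketch (assuming pydantic v2, which provides model_dump_json) of how unset optional fields are dropped before the request goes out:

    from typing import Optional, Union

    from pydantic import BaseModel

    class GenerateEmbedForm(BaseModel):
        model: str
        input: list[str]
        truncate: Optional[bool] = None
        options: Optional[dict] = None
        keep_alive: Optional[Union[int, str]] = None

    form = GenerateEmbedForm(model="all-minilm", input=["first chunk", "second chunk"])
    # exclude_none=True drops truncate/options/keep_alive when unset,
    # so only {"model": ..., "input": [...]} goes over the wire.
    print(form.model_dump_json(exclude_none=True))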
@@ -560,48 +560,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_batch_embeddings(form_data, url_idx)
 
 
 @app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
 
 
 def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
         log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
-            return data["embedding"]
+            return data
         else:
             raise Exception("Something went wrong :/")
     except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+def generate_ollama_batch_embeddings(
+    form_data: GenerateEmbedForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
+
+    if url_idx is None:
+        model = form_data.model
+
+        if ":" not in model:
+            model = f"{model}:latest"
+
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
+    try:
+        r.raise_for_status()
+
+        data = r.json()
+
+        log.info(f"generate_ollama_batch_embeddings {data}")
+
+        if "embeddings" in data:
+            return data
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
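Note the return-shape change above: the single-prompt helper now returns Ollama's full response (return data) rather than the bare embedding list, and the new batch helper does the same for /api/embed. A tiny sketch of how a caller unwraps it (the sample payload is made up):

    # Hypothetical response payload in the shape Ollama's /api/embed returns.
    data = {"model": "all-minilm", "embeddings": [[0.1, 0.2], [0.3, 0.4]]}

    if "embeddings" in data:
        vectors = data["embeddings"]  # one vector per input string
    else:
        raise Exception("Something went wrong :/")

    assert len(vectors) == 2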
@@ -788,8 +763,7 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"{payload = }")
-
+    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
     if "metadata" in payload:
         del payload["metadata"]
 
@@ -824,7 +798,7 @@ async def generate_chat_completion(
 
     url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
-    log.debug(payload)
+    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
     return await post_streaming_url(
         f"{url}/api/chat",
@@ -63,7 +63,7 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_EMBEDDING_BATCH_SIZE,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.sentence_transformer_ef,
     app.state.config.OPENAI_API_KEY,
     app.state.config.OPENAI_API_BASE_URL,
-    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
 )
 
 app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
         "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     }
 
 
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         "openai_config": {
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
-            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         },
     }
 
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
-    batch_size: Optional[int] = None
 
 
 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
+    embedding_batch_size: Optional[int] = 1
 
 
 @app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
         if form_data.openai_config is not None:
             app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
             app.state.config.OPENAI_API_KEY = form_data.openai_config.key
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
-                form_data.openai_config.batch_size
-                if form_data.openai_config.batch_size
-                else 1
-            )
+        app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
 
         update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
 
@@ -334,17 +330,17 @@ async def update_embedding_config(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )
 
         return {
             "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
-                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
            },
        }
     except Exception as e:
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
         app.state.sentence_transformer_ef,
         app.state.config.OPENAI_API_KEY,
         app.state.config.OPENAI_API_BASE_URL,
-        app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     )
 
     embeddings = embedding_function(
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
 
 
 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
 
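The unwrapped vectors argument follows from the engine helpers now returning a list of vectors even for a single query (see the utils changes below), so the call site no longer adds its own brackets. A shape-only sketch of that assumption:

    # Old: the helper returned one flat vector, so the caller wrapped it.
    single_vector = [0.1, 0.2, 0.3]
    # New (assumed): the helper already returns a batch of one.
    batch_of_one = [[0.1, 0.2, 0.3]]

    assert [single_vector] == batch_of_one  # the wrapping the old call did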
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@ def get_embedding_function(
 
     def generate_multiple(query, f):
         if isinstance(query, list):
-            if embedding_engine == "openai":
-                embeddings = []
-                for i in range(0, len(query), batch_size):
-                    embeddings.extend(f(query[i : i + batch_size]))
-                return embeddings
-            else:
-                return [f(q) for q in query]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(f(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return f(query)
 
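The batching loop now applies to every engine rather than OpenAI only. The same chunking pattern in isolation, with a stand-in embedder (embed_batch is hypothetical):

    def generate_multiple(query, f, batch_size):
        # Chunk the input list and concatenate per-batch results,
        # the same range/extend pattern as above.
        if isinstance(query, list):
            embeddings = []
            for i in range(0, len(query), batch_size):
                embeddings.extend(f(query[i : i + batch_size]))
            return embeddings
        return f(query)

    # Stand-in embedder: one fake one-dimensional "vector" per input.
    embed_batch = lambda texts: [[float(len(t))] for t in texts]
    print(generate_multiple(["a", "bb", "ccc"], embed_batch, batch_size=2))
    # -> [[1.0], [2.0], [3.0]]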
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
         return None
 
 
+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
 
@@ -986,10 +986,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 
-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
-    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
-    "rag.embedding_openai_batch_size",
-    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+    "RAG_EMBEDDING_BATCH_SIZE",
+    "rag.embedding_batch_size",
+    int(
+        os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
+        or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
+    ),
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(
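The new setting reads RAG_EMBEDDING_BATCH_SIZE first and falls back to the old OpenAI-only variable, so existing deployments keep their configured batch size. The fallback in isolation:

    import os

    # e.g. an older deployment that only set the OpenAI-specific variable:
    os.environ["RAG_EMBEDDING_OPENAI_BATCH_SIZE"] = "8"

    batch_size = int(
        os.environ.get("RAG_EMBEDDING_BATCH_SIZE")  # preferred new name
        or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")  # legacy fallback
    )
    print(batch_size)  # 8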
@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
 type OpenAIConfigForm = {
 	key: string;
 	url: string;
-	batch_size: number;
 };
 
 type EmbeddingModelUpdateForm = {
 	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
+	embedding_batch_size?: number;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
@@ -38,6 +38,7 @@
 
 	let embeddingEngine = '';
 	let embeddingModel = '';
+	let embeddingBatchSize = 1;
 	let rerankingModel = '';
 
 	let fileMaxSize = null;
@@ -53,7 +54,6 @@
 
 	let OpenAIKey = '';
 	let OpenAIUrl = '';
-	let OpenAIBatchSize = 1;
 
 	let querySettings = {
 		template: '',
@@ -100,12 +100,16 @@
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+				? {
+						embedding_batch_size: embeddingBatchSize
+					}
+				: {}),
 			...(embeddingEngine === 'openai'
 				? {
 						openai_config: {
 							key: OpenAIKey,
-							url: OpenAIUrl,
-							batch_size: OpenAIBatchSize
+							url: OpenAIUrl
 						}
 					}
 				: {})
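With the spreads above, the update request body now carries embedding_batch_size for both engines and drops batch_size from openai_config. Example payloads (values are hypothetical), shown as Python dicts for consistency with the other sketches:

    payload_ollama = {
        "embedding_engine": "ollama",
        "embedding_model": "all-minilm",
        "embedding_batch_size": 16,
    }

    payload_openai = {
        "embedding_engine": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_batch_size": 16,
        "openai_config": {"key": "sk-...", "url": "https://api.openai.com/v1"},
    }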
@@ -193,10 +197,10 @@
 		if (embeddingConfig) {
 			embeddingEngine = embeddingConfig.embedding_engine;
 			embeddingModel = embeddingConfig.embedding_model;
+			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
 
 			OpenAIKey = embeddingConfig.openai_config.key;
 			OpenAIUrl = embeddingConfig.openai_config.url;
-			OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
 		}
 	};
 
@@ -309,6 +313,8 @@
 
 			<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 		</div>
 	{/if}
+	{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
+		<div class="flex mt-0.5 space-x-2">
 			<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
 			<div class=" flex-1">
@@ -318,13 +324,13 @@
 					min="1"
 					max="2048"
 					step="1"
-					bind:value={OpenAIBatchSize}
+					bind:value={embeddingBatchSize}
 					class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
 				/>
 			</div>
 			<div class="">
 				<input
-					bind:value={OpenAIBatchSize}
+					bind:value={embeddingBatchSize}
 					type="number"
 					class=" bg-transparent text-center w-14"
 					min="-2"