Merge pull request #4715 from Peter-De-Ath/ollama-batch-embeddings

feat: support ollama batch processing embeddings
Timothy Jaeryang Baek 2024-10-07 18:32:08 -07:00 committed by GitHub
commit d05c2f56ea
6 changed files with 119 additions and 132 deletions

View File

@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: str
-    truncate: Optional[bool]
+    input: list[str]
+    truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
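
For context: with input typed as list[str], a single call to Ollama's /api/embed endpoint can carry several texts and returns one vector per text. A minimal sketch of such a request, assuming a local Ollama on its default port and an already-pulled embedding model (both assumptions, not part of this diff):

import requests

# Server URL and model name below are assumptions for illustration.
resp = requests.post(
    "http://localhost:11434/api/embed",
    json={
        "model": "nomic-embed-text:latest",
        "input": ["first chunk", "second chunk", "third chunk"],
    },
)
resp.raise_for_status()
data = resp.json()
print(len(data["embeddings"]))  # one vector per input text, so 3
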
@@ -560,48 +560,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-        if ":" not in model:
-            model = f"{model}:latest"
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_batch_embeddings(form_data, url_idx)


 @app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-        if ":" not in model:
-            model = f"{model}:latest"
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)


 def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
         log.info(f"generate_ollama_embeddings {data}")
         if "embedding" in data:
-            return data["embedding"]
+            return data
         else:
             raise Exception("Something went wrong :/")
     except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+def generate_ollama_batch_embeddings(
+    form_data: GenerateEmbedForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
+    if url_idx is None:
+        model = form_data.model
+        if ":" not in model:
+            model = f"{model}:latest"
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
+    try:
+        r.raise_for_status()
+        data = r.json()
+        log.info(f"generate_ollama_batch_embeddings {data}")
+        if "embeddings" in data:
+            return data
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
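
A sketch of how the new helper is consumed, assuming the definitions above and a pulled embedding model (the model name is an assumption):

form = GenerateEmbedForm(
    model="nomic-embed-text:latest",  # assumed model name
    input=["first chunk", "second chunk"],
)
data = generate_ollama_batch_embeddings(form)
vectors = data["embeddings"]  # one vector per input string, in order
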
@@ -788,8 +763,7 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"{payload = }")
+    log.debug(f"generate_chat_completion() - 1.payload = {payload}")

     if "metadata" in payload:
         del payload["metadata"]
@@ -824,7 +798,7 @@ async def generate_chat_completion(
     url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
-    log.debug(payload)
+    log.debug(f"generate_chat_completion() - 2.payload = {payload}")

     return await post_streaming_url(
         f"{url}/api/chat",
View File

@@ -63,7 +63,7 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_EMBEDDING_BATCH_SIZE,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.sentence_transformer_ef,
     app.state.config.OPENAI_API_KEY,
     app.state.config.OPENAI_API_BASE_URL,
-    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
 )

 app.add_middleware(
app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
         "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     }
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         "openai_config": {
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
-            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         },
     }
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
-    batch_size: Optional[int] = None


 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
+    embedding_batch_size: Optional[int] = 1


 @app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
         if form_data.openai_config is not None:
             app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
             app.state.config.OPENAI_API_KEY = form_data.openai_config.key
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
-                form_data.openai_config.batch_size
-                if form_data.openai_config.batch_size
-                else 1
-            )
+        app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size

         update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -334,17 +330,17 @@ async def update_embedding_config(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )

         return {
             "status": True,
             "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
             "openai_config": {
                 "url": app.state.config.OPENAI_API_BASE_URL,
                 "key": app.state.config.OPENAI_API_KEY,
-                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
             },
         }
     except Exception as e:
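
For illustration: after this change a client sends the shared batch size at the top level of the update payload instead of nesting it under openai_config. A hedged sketch of the new request body (the host, the mount path of the retrieval app, the model name, and the token are all assumptions):

import requests

payload = {
    "embedding_engine": "ollama",
    "embedding_model": "nomic-embed-text:latest",  # assumed model
    "embedding_batch_size": 32,
}
requests.post(
    "http://localhost:8080/retrieval/embedding/update",  # assumed mount path
    json=payload,
    headers={"Authorization": "Bearer <admin-token>"},
)
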
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )

         embeddings = embedding_function(

View File

@@ -12,8 +12,8 @@ from langchain_core.documents import Document

 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@ def get_embedding_function(
     def generate_multiple(query, f):
         if isinstance(query, list):
-            if embedding_engine == "openai":
-                embeddings = []
-                for i in range(0, len(query), batch_size):
-                    embeddings.extend(f(query[i : i + batch_size]))
-                return embeddings
-            else:
-                return [f(q) for q in query]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(f(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return f(query)
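
Batching now applies to both engines rather than only OpenAI. The same chunking pattern as a self-contained sketch (the names here are illustrative, not from the diff):

def embed_in_batches(texts, embed_fn, batch_size):
    # Mirrors generate_multiple: slice the inputs and extend one flat result list.
    embeddings = []
    for i in range(0, len(texts), batch_size):
        embeddings.extend(embed_fn(texts[i : i + batch_size]))
    return embeddings

# With batch_size=2, seven texts produce batch calls of sizes 2, 2, 2 and 1.
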
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
         return None


+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
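
The wrapper above normalizes a lone string into a one-element batch before delegating, then unwraps the "embeddings" key. The same idea as a self-contained sketch (batch_fn stands in for generate_ollama_batch_embeddings and is an assumption):

from typing import Union

def embed_one_or_many(model: str, input: Union[str, list[str]], batch_fn):
    # Wrap a single string so the batch helper always receives a list.
    texts = input if isinstance(input, list) else [input]
    return batch_fn({"model": model, "input": texts})["embeddings"]
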

View File

@@ -986,10 +986,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )

-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
-    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
-    "rag.embedding_openai_batch_size",
-    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+    "RAG_EMBEDDING_BATCH_SIZE",
+    "rag.embedding_batch_size",
+    int(
+        os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
+        or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
+    ),
 )

 RAG_RERANKING_MODEL = PersistentConfig(
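
The renamed setting stays backward compatible: when RAG_EMBEDDING_BATCH_SIZE is unset (or empty), the legacy RAG_EMBEDDING_OPENAI_BATCH_SIZE still applies, defaulting to 1. The same precedence as a standalone sketch:

import os

batch_size = int(
    os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
    or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
)
# Neither set -> 1; only the legacy var set -> its value; both set -> the new var wins.
# Note: `or` also falls through on an empty string, not just a missing variable.
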

View File

@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
 type OpenAIConfigForm = {
     key: string;
     url: string;
-    batch_size: number;
 };

 type EmbeddingModelUpdateForm = {
     openai_config?: OpenAIConfigForm;
     embedding_engine: string;
     embedding_model: string;
+    embedding_batch_size?: number;
 };

 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

View File

@@ -38,6 +38,7 @@
     let embeddingEngine = '';
     let embeddingModel = '';
+    let embeddingBatchSize = 1;
     let rerankingModel = '';

     let fileMaxSize = null;
@@ -53,7 +54,6 @@
     let OpenAIKey = '';
     let OpenAIUrl = '';
-    let OpenAIBatchSize = 1;

     let querySettings = {
         template: '',
@@ -100,12 +100,16 @@
         const res = await updateEmbeddingConfig(localStorage.token, {
             embedding_engine: embeddingEngine,
             embedding_model: embeddingModel,
+            ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+                ? {
+                        embedding_batch_size: embeddingBatchSize
+                    }
+                : {}),
             ...(embeddingEngine === 'openai'
                 ? {
                         openai_config: {
                             key: OpenAIKey,
-                            url: OpenAIUrl,
-                            batch_size: OpenAIBatchSize
+                            url: OpenAIUrl
                         }
                     }
                 : {})
@@ -193,10 +197,10 @@
         if (embeddingConfig) {
             embeddingEngine = embeddingConfig.embedding_engine;
             embeddingModel = embeddingConfig.embedding_model;
+            embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;

             OpenAIKey = embeddingConfig.openai_config.key;
             OpenAIUrl = embeddingConfig.openai_config.url;
-            OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
         }
     };
@@ -309,6 +313,8 @@
                 <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
             </div>
         {/if}
+
+        {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
             <div class="flex mt-0.5 space-x-2">
                 <div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
                 <div class=" flex-1">
@@ -318,13 +324,13 @@
                     min="1"
                     max="2048"
                     step="1"
-                    bind:value={OpenAIBatchSize}
+                    bind:value={embeddingBatchSize}
                     class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
                 />
             </div>
             <div class="">
                 <input
-                    bind:value={OpenAIBatchSize}
+                    bind:value={embeddingBatchSize}
                     type="number"
                     class=" bg-transparent text-center w-14"
                     min="-2"