mirror of https://github.com/open-webui/open-webui
synced 2025-03-24 22:49:22 +00:00

Merge pull request #4715 from Peter-De-Ath/ollama-batch-embeddings

feat: support ollama batch processing embeddings

commit d05c2f56ea
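This PR moves Open WebUI's Ollama embedding path onto Ollama's batch-capable /api/embed endpoint (the legacy single-prompt /api/embeddings route becomes a thin wrapper), renames the OpenAI-specific RAG_EMBEDDING_OPENAI_BATCH_SIZE setting to the engine-agnostic RAG_EMBEDDING_BATCH_SIZE, and threads the batch size through the RAG pipeline and the admin UI. As a rough illustration of the two Ollama endpoints (a hedged sketch against a local Ollama server; the host and model name are assumptions, not part of the diff):

import requests

OLLAMA_URL = "http://localhost:11434"  # assumption: default Ollama address
MODEL = "nomic-embed-text"             # illustrative embedding model

# New batch endpoint: "input" may be a list; the response carries "embeddings".
r = requests.post(
    f"{OLLAMA_URL}/api/embed",
    json={"model": MODEL, "input": ["first chunk", "second chunk"]},
)
print(len(r.json()["embeddings"]))  # 2 -- one vector per input string

# Legacy endpoint: one "prompt" per call; the response carries a single "embedding".
r = requests.post(
    f"{OLLAMA_URL}/api/embeddings",
    json={"model": MODEL, "prompt": "first chunk"},
)
print(len(r.json()["embedding"]))  # dimensionality of the single vector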
Ollama backend (open_webui.apps.ollama.main):

@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: str
+    input: list[str]
-    truncate: Optional[bool]
+    truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
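The reshaped form now matches the /api/embed request body: input is a list of strings, and every optional field defaults to None so it can be dropped on serialization. A minimal sketch of how this serializes (assumes pydantic v2; the model name is illustrative):

from typing import Optional, Union

from pydantic import BaseModel


class GenerateEmbedForm(BaseModel):
    model: str
    input: list[str]
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


form = GenerateEmbedForm(model="nomic-embed-text", input=["a", "b"])
# exclude_none drops the unset optionals, so Ollama only sees model/input;
# this is why truncate needed an explicit `= None` default.
print(form.model_dump_json(exclude_none=True))
# -> {"model":"nomic-embed-text","input":["a","b"]}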
@@ -560,48 +560,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_batch_embeddings(form_data, url_idx)


 @app.post("/api/embeddings")
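The endpoint body reduces to a single delegation; model resolution and error handling now live in the shared helper below. A hedged sketch of exercising the batch route through Open WebUI's Ollama proxy (the host, the /ollama mount prefix, and the token are assumptions):

import requests

resp = requests.post(
    "http://localhost:8080/ollama/api/embed",  # assumed default mount of the Ollama app
    headers={"Authorization": "Bearer <your-api-key>"},  # placeholder token
    json={"model": "nomic-embed-text", "input": ["hello", "world"]},
)
print(resp.json()["embeddings"])  # list of vectors, one per input string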
@@ -611,48 +570,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)


 def generate_ollama_embeddings(
@@ -692,7 +610,64 @@ def generate_ollama_embeddings(
         log.info(f"generate_ollama_embeddings {data}")

         if "embedding" in data:
-            return data["embedding"]
+            return data
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+def generate_ollama_batch_embeddings(
+    form_data: GenerateEmbedForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
+
+    if url_idx is None:
+        model = form_data.model
+
+        if ":" not in model:
+            model = f"{model}:latest"
+
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
+    try:
+        r.raise_for_status()
+
+        data = r.json()
+
+        log.info(f"generate_ollama_batch_embeddings {data}")
+
+        if "embeddings" in data:
+            return data
         else:
             raise Exception("Something went wrong :/")
     except Exception as e:
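Note the return-shape change: both helpers now return the full response payload rather than unwrapping it, leaving the caller to pick the key. A runnable toy illustration of the two shapes (values invented, not from a real server):

# Legacy /api/embeddings payload: one vector under "embedding".
single = {"embedding": [0.1, 0.2, 0.3]}
# New /api/embed payload: one vector per input under "embeddings".
batch = {"embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}

assert "embedding" in single and "embeddings" in batch
vectors = batch["embeddings"]
print(len(vectors))  # 2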
@@ -788,8 +763,7 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"{payload = }")
+    log.debug(f"generate_chat_completion() - 1.payload = {payload}")

     if "metadata" in payload:
         del payload["metadata"]
@@ -824,7 +798,7 @@ async def generate_chat_completion(

     url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
-    log.debug(payload)
+    log.debug(f"generate_chat_completion() - 2.payload = {payload}")

     return await post_streaming_url(
         f"{url}/api/chat",
Retrieval (RAG) backend app:

@@ -63,7 +63,7 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_EMBEDDING_BATCH_SIZE,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.config.OPENAI_API_KEY,
|
app.state.config.OPENAI_API_KEY,
|
||||||
app.state.config.OPENAI_API_BASE_URL,
|
app.state.config.OPENAI_API_BASE_URL,
|
||||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@ -267,7 +267,7 @@ async def get_status():
|
|||||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
||||||
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         "openai_config": {
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
-            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         },
     }
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
-    batch_size: Optional[int] = None


 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
+    embedding_batch_size: Optional[int] = 1


 @app.post("/embedding/update")
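With batch_size gone from OpenAIConfigForm and embedding_batch_size promoted to the top-level form, an update request to the /embedding/update route now looks roughly like this (field values are illustrative):

payload = {
    "embedding_engine": "ollama",
    "embedding_model": "nomic-embed-text",  # illustrative model name
    "embedding_batch_size": 16,
    # "openai_config": {"url": ..., "key": ...} is only needed for the openai
    # engine, and it no longer carries batch_size
}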
@ -320,11 +320,7 @@ async def update_embedding_config(
|
|||||||
if form_data.openai_config is not None:
|
if form_data.openai_config is not None:
|
||||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||||
form_data.openai_config.batch_size
|
|
||||||
if form_data.openai_config.batch_size
|
|
||||||
else 1
|
|
||||||
)
|
|
||||||
|
|
||||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||||
|
|
||||||
@@ -334,17 +330,17 @@ async def update_embedding_config(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )

         return {
             "status": True,
             "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
             "openai_config": {
                 "url": app.state.config.OPENAI_API_BASE_URL,
                 "key": app.state.config.OPENAI_API_KEY,
-                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
             },
         }
     except Exception as e:
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
         app.state.sentence_transformer_ef,
         app.state.config.OPENAI_API_KEY,
         app.state.config.OPENAI_API_BASE_URL,
-        app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     )

     embeddings = embedding_function(
Retrieval embedding utilities:

@@ -12,8 +12,8 @@ from langchain_core.documents import Document


 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@ def get_embedding_function(

     def generate_multiple(query, f):
         if isinstance(query, list):
-            if embedding_engine == "openai":
-                embeddings = []
-                for i in range(0, len(query), batch_size):
-                    embeddings.extend(f(query[i : i + batch_size]))
-                return embeddings
-            else:
-                return [f(q) for q in query]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(f(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return f(query)
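generate_multiple is the core of the feature: the openai-only special case is gone, and every engine's embedder f is fed fixed-size slices of the chunk list. A self-contained sketch of the loop with a toy embedder:

def generate_multiple(query, f, embedding_batch_size=2):
    # Same slicing logic as above, with the batch size as an explicit parameter.
    if isinstance(query, list):
        embeddings = []
        for i in range(0, len(query), embedding_batch_size):
            embeddings.extend(f(query[i : i + embedding_batch_size]))
        return embeddings
    return f(query)


# Toy embedder: one fake one-dimensional "vector" per input string.
texts = ["a", "bb", "ccc", "dddd", "eeeee"]
print(generate_multiple(texts, lambda batch: [[len(t)] for t in batch]))
# -> [[1], [2], [3], [4], [5]], produced in batches of 2, 2, and 1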
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
         return None


+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
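The new utils-level generate_ollama_embeddings wrapper normalizes a lone string into a one-element batch before delegating, so callers always get a list of vectors back. A minimal runnable sketch of that normalization, with the HTTP call replaced by a stub:

from typing import Union


def fake_batch_embed(inputs: list[str]) -> dict:
    # Stand-in for generate_ollama_batch_embeddings: one 2-d vector per input.
    return {"embeddings": [[float(len(s)), 0.0] for s in inputs]}


def embed(input: Union[str, list[str]]) -> list[list[float]]:
    batch = input if isinstance(input, list) else [input]
    return fake_batch_embed(batch)["embeddings"]


print(embed("single chunk"))  # -> [[12.0, 0.0]]
print(embed(["a", "bb"]))     # -> [[1.0, 0.0], [2.0, 0.0]]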
Configuration (open_webui.config):

@@ -986,10 +986,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )

-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
-    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
-    "rag.embedding_openai_batch_size",
-    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+    "RAG_EMBEDDING_BATCH_SIZE",
+    "rag.embedding_batch_size",
+    int(
+        os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
+        or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
+    ),
 )

 RAG_RERANKING_MODEL = PersistentConfig(
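The default expression keeps the old environment variable as a fallback, so existing deployments that only set RAG_EMBEDDING_OPENAI_BATCH_SIZE keep their batch size. A runnable illustration of the precedence:

import os

os.environ.pop("RAG_EMBEDDING_BATCH_SIZE", None)
os.environ["RAG_EMBEDDING_OPENAI_BATCH_SIZE"] = "8"

batch_size = int(
    os.environ.get("RAG_EMBEDDING_BATCH_SIZE")
    or os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")
)
print(batch_size)  # -> 8: the legacy variable applies when the new one is unset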
Frontend retrieval API client (TypeScript):

@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
 type OpenAIConfigForm = {
     key: string;
     url: string;
-    batch_size: number;
 };

 type EmbeddingModelUpdateForm = {
     openai_config?: OpenAIConfigForm;
     embedding_engine: string;
     embedding_model: string;
+    embedding_batch_size?: number;
 };

 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
Admin document settings component (Svelte):

@@ -38,6 +38,7 @@

     let embeddingEngine = '';
     let embeddingModel = '';
+    let embeddingBatchSize = 1;
     let rerankingModel = '';

     let fileMaxSize = null;
@@ -53,7 +54,6 @@

     let OpenAIKey = '';
     let OpenAIUrl = '';
-    let OpenAIBatchSize = 1;

     let querySettings = {
         template: '',
@@ -100,12 +100,16 @@
     const res = await updateEmbeddingConfig(localStorage.token, {
         embedding_engine: embeddingEngine,
         embedding_model: embeddingModel,
+        ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+            ? {
+                    embedding_batch_size: embeddingBatchSize
+                }
+            : {}),
         ...(embeddingEngine === 'openai'
             ? {
                     openai_config: {
                         key: OpenAIKey,
-                        url: OpenAIUrl,
-                        batch_size: OpenAIBatchSize
+                        url: OpenAIUrl
                     }
                 }
             : {})
@@ -193,10 +197,10 @@
     if (embeddingConfig) {
         embeddingEngine = embeddingConfig.embedding_engine;
         embeddingModel = embeddingConfig.embedding_model;
+        embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;

         OpenAIKey = embeddingConfig.openai_config.key;
         OpenAIUrl = embeddingConfig.openai_config.url;
-        OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
     }
 };
@@ -309,6 +313,8 @@

     <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 </div>
+{/if}
+{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
 <div class="flex mt-0.5 space-x-2">
     <div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
     <div class=" flex-1">
@@ -318,13 +324,13 @@
         min="1"
         max="2048"
         step="1"
-        bind:value={OpenAIBatchSize}
+        bind:value={embeddingBatchSize}
         class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
     />
 </div>
 <div class="">
     <input
-        bind:value={OpenAIBatchSize}
+        bind:value={embeddingBatchSize}
         type="number"
         class=" bg-transparent text-center w-14"
         min="-2"