mirror of
https://github.com/open-webui/open-webui
synced 2025-06-10 00:17:52 +00:00
Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option
feat: Add Azure OpenAI embedding support
This commit is contained in:
commit
ff353578db
@ -2184,6 +2184,27 @@ RAG_OPENAI_API_KEY = PersistentConfig(
|
|||||||
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RAG_AZURE_OPENAI_BASE_URL = PersistentConfig(
|
||||||
|
"RAG_AZURE_OPENAI_BASE_URL",
|
||||||
|
"rag.azure_openai.base_url",
|
||||||
|
os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""),
|
||||||
|
)
|
||||||
|
RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
|
||||||
|
"RAG_AZURE_OPENAI_API_KEY",
|
||||||
|
"rag.azure_openai.api_key",
|
||||||
|
os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
|
||||||
|
)
|
||||||
|
RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig(
|
||||||
|
"RAG_AZURE_OPENAI_DEPLOYMENT",
|
||||||
|
"rag.azure_openai.deployment",
|
||||||
|
os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""),
|
||||||
|
)
|
||||||
|
RAG_AZURE_OPENAI_VERSION = PersistentConfig(
|
||||||
|
"RAG_AZURE_OPENAI_VERSION",
|
||||||
|
"rag.azure_openai.version",
|
||||||
|
os.getenv("RAG_AZURE_OPENAI_VERSION", ""),
|
||||||
|
)
|
||||||
|
|
||||||
RAG_OLLAMA_BASE_URL = PersistentConfig(
|
RAG_OLLAMA_BASE_URL = PersistentConfig(
|
||||||
"RAG_OLLAMA_BASE_URL",
|
"RAG_OLLAMA_BASE_URL",
|
||||||
"rag.ollama.url",
|
"rag.ollama.url",
|
||||||
|
@ -207,6 +207,10 @@ from open_webui.config import (
|
|||||||
RAG_FILE_MAX_SIZE,
|
RAG_FILE_MAX_SIZE,
|
||||||
RAG_OPENAI_API_BASE_URL,
|
RAG_OPENAI_API_BASE_URL,
|
||||||
RAG_OPENAI_API_KEY,
|
RAG_OPENAI_API_KEY,
|
||||||
|
RAG_AZURE_OPENAI_BASE_URL,
|
||||||
|
RAG_AZURE_OPENAI_API_KEY,
|
||||||
|
RAG_AZURE_OPENAI_DEPLOYMENT,
|
||||||
|
RAG_AZURE_OPENAI_VERSION,
|
||||||
RAG_OLLAMA_BASE_URL,
|
RAG_OLLAMA_BASE_URL,
|
||||||
RAG_OLLAMA_API_KEY,
|
RAG_OLLAMA_API_KEY,
|
||||||
CHUNK_OVERLAP,
|
CHUNK_OVERLAP,
|
||||||
@ -717,6 +721,11 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
|||||||
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||||
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||||
|
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION
|
||||||
|
|
||||||
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
||||||
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
||||||
|
|
||||||
@ -811,14 +820,32 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||||||
(
|
(
|
||||||
app.state.config.RAG_OPENAI_API_BASE_URL
|
app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else app.state.config.RAG_OLLAMA_BASE_URL
|
else (
|
||||||
|
app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||||
|
)
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
app.state.config.RAG_OPENAI_API_KEY
|
app.state.config.RAG_OPENAI_API_KEY
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else app.state.config.RAG_OLLAMA_API_KEY
|
else (
|
||||||
|
app.state.config.RAG_OLLAMA_API_KEY
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||||
|
)
|
||||||
),
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
|
(
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
(
|
||||||
|
app.state.config.RAG_AZURE_OPENAI_VERSION
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
|||||||
import requests
|
import requests
|
||||||
import hashlib
|
import hashlib
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import time
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||||
@ -400,12 +401,14 @@ def get_embedding_function(
|
|||||||
url,
|
url,
|
||||||
key,
|
key,
|
||||||
embedding_batch_size,
|
embedding_batch_size,
|
||||||
|
deployment=None,
|
||||||
|
version=None,
|
||||||
):
|
):
|
||||||
if embedding_engine == "":
|
if embedding_engine == "":
|
||||||
return lambda query, prefix=None, user=None: embedding_function.encode(
|
return lambda query, prefix=None, user=None: embedding_function.encode(
|
||||||
query, **({"prompt": prefix} if prefix else {})
|
query, **({"prompt": prefix} if prefix else {})
|
||||||
).tolist()
|
).tolist()
|
||||||
elif embedding_engine in ["ollama", "openai"]:
|
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
|
||||||
func = lambda query, prefix=None, user=None: generate_embeddings(
|
func = lambda query, prefix=None, user=None: generate_embeddings(
|
||||||
engine=embedding_engine,
|
engine=embedding_engine,
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
@ -414,6 +417,8 @@ def get_embedding_function(
|
|||||||
url=url,
|
url=url,
|
||||||
key=key,
|
key=key,
|
||||||
user=user,
|
user=user,
|
||||||
|
deployment=deployment,
|
||||||
|
version=version,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_multiple(query, prefix, user, func):
|
def generate_multiple(query, prefix, user, func):
|
||||||
@ -697,6 +702,61 @@ def generate_openai_batch_embeddings(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_after_seconds(header_value, default: float = 1.0) -> float:
    """Convert a ``Retry-After`` header value into a safe sleep duration.

    Args:
        header_value: Raw header value, or ``None`` when the header is absent.
        default: Seconds to wait when the value is missing or unusable.

    Returns:
        A finite, non-negative number of seconds.
    """
    try:
        delay = float(header_value)
    except (TypeError, ValueError):
        # The header may carry an HTTP-date instead of a number of seconds;
        # fall back to the default rather than aborting the retry loop.
        return default
    # NaN fails both comparisons; negative and infinite values would make
    # time.sleep() raise.
    return delay if 0 <= delay < float("inf") else default


def generate_azure_openai_batch_embeddings(
    deployment: str,
    texts: list[str],
    url: str,
    key: str = "",
    model: str = "",
    version: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    """Request embeddings for a batch of texts from an Azure OpenAI deployment.

    Posts to ``{url}/openai/deployments/{deployment}/embeddings`` with the
    given ``api-version`` and retries up to five times while the service
    answers with HTTP 429.

    Args:
        deployment: Deployment name inserted into the request path.
        texts: Texts to embed, sent as the ``input`` field.
        url: Base URL of the Azure OpenAI resource; a trailing slash is
            tolerated.
        key: API key sent in the ``api-key`` header.
        model: Value sent as the ``model`` field of the request body.
        version: Value of the ``api-version`` query parameter.
        prefix: Optional prefix, added to the body under
            ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` when both are strings.
        user: Optional user whose name, id, email and role are forwarded as
            ``X-OpenWebUI-User-*`` headers when
            ``ENABLE_FORWARD_USER_INFO_HEADERS`` is enabled.

    Returns:
        One embedding per element of the response's ``data`` list, or
        ``None`` when the request fails or every attempt is rate-limited.
        Exceptions are logged and never propagated to the caller.
    """
    max_attempts = 5
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        # Strip a trailing slash so the path never contains "//openai".
        endpoint = f"{url.rstrip('/')}/openai/deployments/{deployment}/embeddings?api-version={version}"

        # The headers do not change between attempts, so build them once.
        headers = {
            "Content-Type": "application/json",
            "api-key": key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers.update(
                {
                    "X-OpenWebUI-User-Name": user.name,
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
            )

        for attempt in range(1, max_attempts + 1):
            r = requests.post(endpoint, headers=headers, json=json_data)
            if r.status_code == 429:
                # Only wait when another attempt will actually follow.
                if attempt < max_attempts:
                    time.sleep(_retry_after_seconds(r.headers.get("Retry-After")))
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            raise Exception("Something went wrong :/")

        log.warning(
            f"generate_azure_openai_batch_embeddings: still rate limited after {max_attempts} attempts"
        )
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
def generate_ollama_batch_embeddings(
|
def generate_ollama_batch_embeddings(
|
||||||
model: str,
|
model: str,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@ -794,6 +854,32 @@ def generate_embeddings(
|
|||||||
model, [text], url, key, prefix, user
|
model, [text], url, key, prefix, user
|
||||||
)
|
)
|
||||||
return embeddings[0] if isinstance(text, str) else embeddings
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
elif engine == "azure_openai":
|
||||||
|
deployment = kwargs.get("deployment", "")
|
||||||
|
version = kwargs.get("version", "")
|
||||||
|
if isinstance(text, list):
|
||||||
|
embeddings = generate_azure_openai_batch_embeddings(
|
||||||
|
deployment,
|
||||||
|
text,
|
||||||
|
url,
|
||||||
|
key,
|
||||||
|
model,
|
||||||
|
version,
|
||||||
|
prefix,
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embeddings = generate_azure_openai_batch_embeddings(
|
||||||
|
deployment,
|
||||||
|
[text],
|
||||||
|
url,
|
||||||
|
key,
|
||||||
|
model,
|
||||||
|
version,
|
||||||
|
prefix,
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
|
@ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
|||||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||||
},
|
},
|
||||||
|
"azure_openai_config": {
|
||||||
|
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
||||||
|
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
||||||
|
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
|
||||||
|
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel):
|
|||||||
key: str
|
key: str
|
||||||
|
|
||||||
|
|
||||||
|
# Azure OpenAI settings accepted as the optional `azure_openai_config` part of
# an embedding-config update; all four fields are required when it is supplied.
class AzureOpenAIConfigForm(BaseModel):
    url: str  # stored as RAG_AZURE_OPENAI_BASE_URL
    key: str  # stored as RAG_AZURE_OPENAI_API_KEY
    deployment: str  # stored as RAG_AZURE_OPENAI_DEPLOYMENT
    version: str  # stored as RAG_AZURE_OPENAI_VERSION
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelUpdateForm(BaseModel):
|
class EmbeddingModelUpdateForm(BaseModel):
|
||||||
openai_config: Optional[OpenAIConfigForm] = None
|
openai_config: Optional[OpenAIConfigForm] = None
|
||||||
ollama_config: Optional[OllamaConfigForm] = None
|
ollama_config: Optional[OllamaConfigForm] = None
|
||||||
|
azure_openai_config: Optional[AzureOpenAIConfigForm] = None
|
||||||
embedding_engine: str
|
embedding_engine: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
embedding_batch_size: Optional[int] = 1
|
embedding_batch_size: Optional[int] = 1
|
||||||
@ -271,7 +285,7 @@ async def update_embedding_config(
|
|||||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||||
|
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
|
||||||
if form_data.openai_config is not None:
|
if form_data.openai_config is not None:
|
||||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||||
form_data.openai_config.url
|
form_data.openai_config.url
|
||||||
@ -288,6 +302,20 @@ async def update_embedding_config(
|
|||||||
form_data.ollama_config.key
|
form_data.ollama_config.key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if form_data.azure_openai_config is not None:
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||||
|
form_data.azure_openai_config.url
|
||||||
|
)
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||||
|
form_data.azure_openai_config.key
|
||||||
|
)
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
|
||||||
|
form_data.azure_openai_config.deployment
|
||||||
|
)
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
|
||||||
|
form_data.azure_openai_config.version
|
||||||
|
)
|
||||||
|
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||||
form_data.embedding_batch_size
|
form_data.embedding_batch_size
|
||||||
)
|
)
|
||||||
@ -304,14 +332,32 @@ async def update_embedding_config(
|
|||||||
(
|
(
|
||||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
else (
|
||||||
|
request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||||
|
)
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
request.app.state.config.RAG_OPENAI_API_KEY
|
request.app.state.config.RAG_OPENAI_API_KEY
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
else (
|
||||||
|
request.app.state.config.RAG_OLLAMA_API_KEY
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||||
|
)
|
||||||
),
|
),
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
|
(
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
(
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_VERSION
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -327,6 +373,12 @@ async def update_embedding_config(
|
|||||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||||
},
|
},
|
||||||
|
"azure_openai_config": {
|
||||||
|
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
||||||
|
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
||||||
|
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
|
||||||
|
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Problem updating embedding model: {e}")
|
log.exception(f"Problem updating embedding model: {e}")
|
||||||
@ -1129,14 +1181,32 @@ def save_docs_to_vector_db(
|
|||||||
(
|
(
|
||||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
else (
|
||||||
|
request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||||
|
)
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
request.app.state.config.RAG_OPENAI_API_KEY
|
request.app.state.config.RAG_OPENAI_API_KEY
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
else (
|
||||||
|
request.app.state.config.RAG_OLLAMA_API_KEY
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||||
|
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||||
|
)
|
||||||
),
|
),
|
||||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
|
(
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
(
|
||||||
|
request.app.state.config.RAG_AZURE_OPENAI_VERSION
|
||||||
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings = embedding_function(
|
embeddings = embedding_function(
|
||||||
|
@ -180,15 +180,23 @@ export const getEmbeddingConfig = async (token: string) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
type OpenAIConfigForm = {
|
type OpenAIConfigForm = {
|
||||||
key: string;
|
key: string;
|
||||||
url: string;
|
url: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
type AzureOpenAIConfigForm = {
|
||||||
|
key: string;
|
||||||
|
url: string;
|
||||||
|
deployment: string;
|
||||||
|
version: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type EmbeddingModelUpdateForm = {
|
type EmbeddingModelUpdateForm = {
|
||||||
openai_config?: OpenAIConfigForm;
|
openai_config?: OpenAIConfigForm;
|
||||||
embedding_engine: string;
|
azure_openai_config?: AzureOpenAIConfigForm;
|
||||||
embedding_model: string;
|
embedding_engine: string;
|
||||||
embedding_batch_size?: number;
|
embedding_model: string;
|
||||||
|
embedding_batch_size?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
|
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
|
||||||
|
@ -43,8 +43,13 @@
|
|||||||
let embeddingBatchSize = 1;
|
let embeddingBatchSize = 1;
|
||||||
let rerankingModel = '';
|
let rerankingModel = '';
|
||||||
|
|
||||||
let OpenAIUrl = '';
|
let OpenAIUrl = '';
|
||||||
let OpenAIKey = '';
|
let OpenAIKey = '';
|
||||||
|
|
||||||
|
let AzureOpenAIUrl = '';
|
||||||
|
let AzureOpenAIKey = '';
|
||||||
|
let AzureOpenAIDeployment = '';
|
||||||
|
let AzureOpenAIVersion = '';
|
||||||
|
|
||||||
let OllamaUrl = '';
|
let OllamaUrl = '';
|
||||||
let OllamaKey = '';
|
let OllamaKey = '';
|
||||||
@ -86,27 +91,40 @@
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((embeddingEngine === 'openai' && OpenAIKey === '') || OpenAIUrl === '') {
|
if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) {
|
||||||
toast.error($i18n.t('OpenAI URL/Key required.'));
|
toast.error($i18n.t('OpenAI URL/Key required.'));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
embeddingEngine === 'azure_openai' &&
|
||||||
|
(AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '')
|
||||||
|
) {
|
||||||
|
toast.error($i18n.t('OpenAI URL/Key required.'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
console.debug('Update embedding model attempt:', embeddingModel);
|
console.debug('Update embedding model attempt:', embeddingModel);
|
||||||
|
|
||||||
updateEmbeddingModelLoading = true;
|
updateEmbeddingModelLoading = true;
|
||||||
const res = await updateEmbeddingConfig(localStorage.token, {
|
const res = await updateEmbeddingConfig(localStorage.token, {
|
||||||
embedding_engine: embeddingEngine,
|
embedding_engine: embeddingEngine,
|
||||||
embedding_model: embeddingModel,
|
embedding_model: embeddingModel,
|
||||||
embedding_batch_size: embeddingBatchSize,
|
embedding_batch_size: embeddingBatchSize,
|
||||||
ollama_config: {
|
ollama_config: {
|
||||||
key: OllamaKey,
|
key: OllamaKey,
|
||||||
url: OllamaUrl
|
url: OllamaUrl
|
||||||
},
|
},
|
||||||
openai_config: {
|
openai_config: {
|
||||||
key: OpenAIKey,
|
key: OpenAIKey,
|
||||||
url: OpenAIUrl
|
url: OpenAIUrl
|
||||||
}
|
},
|
||||||
}).catch(async (error) => {
|
azure_openai_config: {
|
||||||
|
key: AzureOpenAIKey,
|
||||||
|
url: AzureOpenAIUrl,
|
||||||
|
deployment: AzureOpenAIDeployment,
|
||||||
|
version: AzureOpenAIVersion
|
||||||
|
}
|
||||||
|
}).catch(async (error) => {
|
||||||
toast.error(`${error}`);
|
toast.error(`${error}`);
|
||||||
await setEmbeddingConfig();
|
await setEmbeddingConfig();
|
||||||
return null;
|
return null;
|
||||||
@ -200,13 +218,18 @@
|
|||||||
embeddingModel = embeddingConfig.embedding_model;
|
embeddingModel = embeddingConfig.embedding_model;
|
||||||
embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
|
embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
|
||||||
|
|
||||||
OpenAIKey = embeddingConfig.openai_config.key;
|
OpenAIKey = embeddingConfig.openai_config.key;
|
||||||
OpenAIUrl = embeddingConfig.openai_config.url;
|
OpenAIUrl = embeddingConfig.openai_config.url;
|
||||||
|
|
||||||
OllamaKey = embeddingConfig.ollama_config.key;
|
OllamaKey = embeddingConfig.ollama_config.key;
|
||||||
OllamaUrl = embeddingConfig.ollama_config.url;
|
OllamaUrl = embeddingConfig.ollama_config.url;
|
||||||
}
|
|
||||||
};
|
AzureOpenAIKey = embeddingConfig.azure_openai_config.key;
|
||||||
|
AzureOpenAIUrl = embeddingConfig.azure_openai_config.url;
|
||||||
|
AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment;
|
||||||
|
AzureOpenAIVersion = embeddingConfig.azure_openai_config.version;
|
||||||
|
}
|
||||||
|
};
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
await setEmbeddingConfig();
|
await setEmbeddingConfig();
|
||||||
|
|
||||||
@ -603,23 +626,26 @@
|
|||||||
bind:value={embeddingEngine}
|
bind:value={embeddingEngine}
|
||||||
placeholder="Select an embedding model engine"
|
placeholder="Select an embedding model engine"
|
||||||
on:change={(e) => {
|
on:change={(e) => {
|
||||||
if (e.target.value === 'ollama') {
|
if (e.target.value === 'ollama') {
|
||||||
embeddingModel = '';
|
embeddingModel = '';
|
||||||
} else if (e.target.value === 'openai') {
|
} else if (e.target.value === 'openai') {
|
||||||
embeddingModel = 'text-embedding-3-small';
|
embeddingModel = 'text-embedding-3-small';
|
||||||
} else if (e.target.value === '') {
|
} else if (e.target.value === 'azure_openai') {
|
||||||
embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
|
embeddingModel = 'text-embedding-3-small';
|
||||||
}
|
} else if (e.target.value === '') {
|
||||||
|
embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
|
<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
|
||||||
<option value="ollama">{$i18n.t('Ollama')}</option>
|
<option value="ollama">{$i18n.t('Ollama')}</option>
|
||||||
<option value="openai">{$i18n.t('OpenAI')}</option>
|
<option value="openai">{$i18n.t('OpenAI')}</option>
|
||||||
|
<option value="azure_openai">Azure OpenAI</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if embeddingEngine === 'openai'}
|
{#if embeddingEngine === 'openai'}
|
||||||
<div class="my-0.5 flex gap-2 pr-2">
|
<div class="my-0.5 flex gap-2 pr-2">
|
||||||
<input
|
<input
|
||||||
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
||||||
@ -630,7 +656,7 @@
|
|||||||
|
|
||||||
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
|
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
|
||||||
</div>
|
</div>
|
||||||
{:else if embeddingEngine === 'ollama'}
|
{:else if embeddingEngine === 'ollama'}
|
||||||
<div class="my-0.5 flex gap-2 pr-2">
|
<div class="my-0.5 flex gap-2 pr-2">
|
||||||
<input
|
<input
|
||||||
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
||||||
@ -645,7 +671,33 @@
|
|||||||
required={false}
|
required={false}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{:else if embeddingEngine === 'azure_openai'}
|
||||||
|
<div class="my-0.5 flex flex-col gap-2 pr-2 w-full">
|
||||||
|
<div class="flex gap-2">
|
||||||
|
<input
|
||||||
|
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
||||||
|
placeholder={$i18n.t('API Base URL')}
|
||||||
|
bind:value={AzureOpenAIUrl}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} />
|
||||||
|
</div>
|
||||||
|
<div class="flex gap-2">
|
||||||
|
<input
|
||||||
|
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
||||||
|
placeholder="Deployment"
|
||||||
|
bind:value={AzureOpenAIDeployment}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<input
|
||||||
|
class="flex-1 w-full text-sm bg-transparent outline-hidden"
|
||||||
|
placeholder="Version"
|
||||||
|
bind:value={AzureOpenAIVersion}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class=" mb-2.5 flex flex-col w-full">
|
<div class=" mb-2.5 flex flex-col w-full">
|
||||||
@ -741,7 +793,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
|
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
|
||||||
<div class=" mb-2.5 flex w-full justify-between">
|
<div class=" mb-2.5 flex w-full justify-between">
|
||||||
<div class=" self-center text-xs font-medium">
|
<div class=" self-center text-xs font-medium">
|
||||||
{$i18n.t('Embedding Batch Size')}
|
{$i18n.t('Embedding Batch Size')}
|
||||||
|
Loading…
Reference in New Issue
Block a user