diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 8de2a04cb..87258fb2c 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -75,6 +75,8 @@ from open_webui.config import ( RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, RAG_RELEVANCE_THRESHOLD, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, @@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE @@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } @@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel): key: str +class OllamaConfigForm(BaseModel): + url: str + key: str + + class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None + ollama_config: Optional[OllamaConfigForm] = None embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 @@ -350,6 +373,11 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key + + if form_data.ollama_config is not None: + app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url + app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -358,8 +386,16 @@ async def update_embedding_config( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) @@ -372,6 +408,10 @@ async def update_embedding_config( "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, }, + "ollama_config": { + "url": app.state.config.OLLAMA_BASE_URL, + "key": app.state.config.OLLAMA_API_KEY, + }, } except Exception as e: log.exception(f"Problem updating embedding model: {e}") @@ -785,8 +825,16 @@ def save_docs_to_vector_db( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, + ( + app.state.config.OPENAI_API_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_BASE_URL + ), + ( + app.state.config.OPENAI_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "openai" + else app.state.config.OLLAMA_API_KEY + ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 7d92b7350..77d97814c 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document - -from open_webui.apps.ollama.main import ( - GenerateEmbedForm, - generate_ollama_batch_embeddings, -) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message @@ -285,25 +280,19 @@ def get_embedding_function( embedding_engine, embedding_model, embedding_function, - openai_key, - openai_url, + url, + key, embedding_batch_size, ): if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif embedding_engine in ["ollama", "openai"]: - - # Wrapper to run the async generate_embeddings synchronously. - def sync_generate_embeddings(*args, **kwargs): - return asyncio.run(generate_embeddings(*args, **kwargs)) - - # Semantic expectation from the original version (using sync wrapper). - func = lambda query: sync_generate_embeddings( + func = lambda query: generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, - key=openai_key if embedding_engine == "openai" else "", - url=openai_url if embedding_engine == "openai" else "", + url=url, + key=key, ) def generate_multiple(query, func): @@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False): return model -async def generate_openai_batch_embeddings( - model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" +def generate_openai_batch_embeddings( + model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "" ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings( return None -async def generate_embeddings( - engine: str, model: str, text: Union[str, list[str]], **kwargs -): +def generate_ollama_batch_embeddings( + model: str, texts: list[str], url: str, key: str +) -> Optional[list[list[float]]]: + try: + r = requests.post( + f"{url}/api/embed", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {key}", + }, + json={"input": texts, "model": model}, + ) + r.raise_for_status() + data = r.json() + + print(data) + if "embeddings" in data: + return data["embeddings"] + else: + raise "Something went wrong :/" + except Exception as e: + print(e) + return None + + +def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): + url = kwargs.get("url", "") + key = kwargs.get("key", "") + if engine == "ollama": if isinstance(text, list): - embeddings = await generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": text}) + embeddings = generate_ollama_batch_embeddings( + **{"model": model, "texts": text, "url": url, "key": key} ) else: - embeddings = await generate_ollama_batch_embeddings( - GenerateEmbedForm(**{"model": model, "input": [text]}) + embeddings = generate_ollama_batch_embeddings( + **{"model": model, "texts": [text], "url": url, "key": key} ) - return ( - embeddings["embeddings"][0] - if isinstance(text, str) - else embeddings["embeddings"] - ) + return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": - key = kwargs.get("key", "") - url = kwargs.get("url", "https://api.openai.com/v1") - if isinstance(text, list): - embeddings = await generate_openai_batch_embeddings(model, text, key, url) + embeddings = generate_openai_batch_embeddings(model, text, url, key) else: - embeddings = await generate_openai_batch_embeddings(model, [text], key, url) + embeddings = generate_openai_batch_embeddings(model, [text], url, key) return embeddings[0] if isinstance(text, str) else embeddings diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 1b2f119ff..a596c293c 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -56,8 +56,11 @@ let chunkOverlap = 0; let pdfExtractImages = true; - let OpenAIKey = ''; let OpenAIUrl = ''; + let OpenAIKey = ''; + + let OllamaUrl = ''; + let OllamaKey = ''; let querySettings = { template: '', @@ -104,19 +107,15 @@ const res = await updateEmbeddingConfig(localStorage.token, { embedding_engine: embeddingEngine, embedding_model: embeddingModel, - ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama' - ? { - embedding_batch_size: embeddingBatchSize - } - : {}), - ...(embeddingEngine === 'openai' - ? { - openai_config: { - key: OpenAIKey, - url: OpenAIUrl - } - } - : {}) + embedding_batch_size: embeddingBatchSize, + ollama_config: { + key: OllamaKey, + url: OllamaUrl + }, + openai_config: { + key: OpenAIKey, + url: OpenAIUrl + } }).catch(async (error) => { toast.error(error); await setEmbeddingConfig(); @@ -206,6 +205,9 @@ OpenAIKey = embeddingConfig.openai_config.key; OpenAIUrl = embeddingConfig.openai_config.url; + + OllamaKey = embeddingConfig.ollama_config.key; + OllamaUrl = embeddingConfig.ollama_config.url; } }; @@ -310,7 +312,7 @@ {#if embeddingEngine === 'openai'} -
+
+ {:else if embeddingEngine === 'ollama'} +
+ + + +
{/if} + {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
{$i18n.t('Embedding Batch Size')}