mirror of
https://github.com/open-webui/open-webui
synced 2025-06-16 19:31:52 +00:00
refac: ollama setting for rag
This commit is contained in:
parent
e3485d2d88
commit
20321e5271
@ -75,6 +75,8 @@ from open_webui.config import (
|
|||||||
RAG_FILE_MAX_SIZE,
|
RAG_FILE_MAX_SIZE,
|
||||||
RAG_OPENAI_API_BASE_URL,
|
RAG_OPENAI_API_BASE_URL,
|
||||||
RAG_OPENAI_API_KEY,
|
RAG_OPENAI_API_KEY,
|
||||||
|
RAG_OLLAMA_BASE_URL,
|
||||||
|
RAG_OLLAMA_API_KEY,
|
||||||
RAG_RELEVANCE_THRESHOLD,
|
RAG_RELEVANCE_THRESHOLD,
|
||||||
RAG_RERANKING_MODEL,
|
RAG_RERANKING_MODEL,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
|||||||
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||||
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||||
|
|
||||||
|
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
||||||
|
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
||||||
|
|
||||||
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
||||||
|
|
||||||
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
||||||
@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
|||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.config.OPENAI_API_KEY,
|
(
|
||||||
app.state.config.OPENAI_API_BASE_URL,
|
app.state.config.OPENAI_API_BASE_URL
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_BASE_URL
|
||||||
|
),
|
||||||
|
(
|
||||||
|
app.state.config.OPENAI_API_KEY
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_API_KEY
|
||||||
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|||||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||||
"key": app.state.config.OPENAI_API_KEY,
|
"key": app.state.config.OPENAI_API_KEY,
|
||||||
},
|
},
|
||||||
|
"ollama_config": {
|
||||||
|
"url": app.state.config.OLLAMA_BASE_URL,
|
||||||
|
"key": app.state.config.OLLAMA_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
|
|||||||
key: str
|
key: str
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaConfigForm(BaseModel):
|
||||||
|
url: str
|
||||||
|
key: str
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelUpdateForm(BaseModel):
|
class EmbeddingModelUpdateForm(BaseModel):
|
||||||
openai_config: Optional[OpenAIConfigForm] = None
|
openai_config: Optional[OpenAIConfigForm] = None
|
||||||
|
ollama_config: Optional[OllamaConfigForm] = None
|
||||||
embedding_engine: str
|
embedding_engine: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
embedding_batch_size: Optional[int] = 1
|
embedding_batch_size: Optional[int] = 1
|
||||||
@ -350,6 +373,11 @@ async def update_embedding_config(
|
|||||||
if form_data.openai_config is not None:
|
if form_data.openai_config is not None:
|
||||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||||
|
|
||||||
|
if form_data.ollama_config is not None:
|
||||||
|
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
|
||||||
|
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
|
||||||
|
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||||
|
|
||||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||||
@ -358,8 +386,16 @@ async def update_embedding_config(
|
|||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.config.OPENAI_API_KEY,
|
(
|
||||||
app.state.config.OPENAI_API_BASE_URL,
|
app.state.config.OPENAI_API_BASE_URL
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_BASE_URL
|
||||||
|
),
|
||||||
|
(
|
||||||
|
app.state.config.OPENAI_API_KEY
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_API_KEY
|
||||||
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -372,6 +408,10 @@ async def update_embedding_config(
|
|||||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||||
"key": app.state.config.OPENAI_API_KEY,
|
"key": app.state.config.OPENAI_API_KEY,
|
||||||
},
|
},
|
||||||
|
"ollama_config": {
|
||||||
|
"url": app.state.config.OLLAMA_BASE_URL,
|
||||||
|
"key": app.state.config.OLLAMA_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Problem updating embedding model: {e}")
|
log.exception(f"Problem updating embedding model: {e}")
|
||||||
@ -785,8 +825,16 @@ def save_docs_to_vector_db(
|
|||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
app.state.config.OPENAI_API_KEY,
|
(
|
||||||
app.state.config.OPENAI_API_BASE_URL,
|
app.state.config.OPENAI_API_BASE_URL
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_BASE_URL
|
||||||
|
),
|
||||||
|
(
|
||||||
|
app.state.config.OPENAI_API_KEY
|
||||||
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
|
else app.state.config.OLLAMA_API_KEY
|
||||||
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
|
|||||||
from langchain_community.retrievers import BM25Retriever
|
from langchain_community.retrievers import BM25Retriever
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
|
||||||
from open_webui.apps.ollama.main import (
|
|
||||||
GenerateEmbedForm,
|
|
||||||
generate_ollama_batch_embeddings,
|
|
||||||
)
|
|
||||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||||
from open_webui.utils.misc import get_last_user_message
|
from open_webui.utils.misc import get_last_user_message
|
||||||
|
|
||||||
@ -285,25 +280,19 @@ def get_embedding_function(
|
|||||||
embedding_engine,
|
embedding_engine,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
embedding_function,
|
embedding_function,
|
||||||
openai_key,
|
url,
|
||||||
openai_url,
|
key,
|
||||||
embedding_batch_size,
|
embedding_batch_size,
|
||||||
):
|
):
|
||||||
if embedding_engine == "":
|
if embedding_engine == "":
|
||||||
return lambda query: embedding_function.encode(query).tolist()
|
return lambda query: embedding_function.encode(query).tolist()
|
||||||
elif embedding_engine in ["ollama", "openai"]:
|
elif embedding_engine in ["ollama", "openai"]:
|
||||||
|
func = lambda query: generate_embeddings(
|
||||||
# Wrapper to run the async generate_embeddings synchronously.
|
|
||||||
def sync_generate_embeddings(*args, **kwargs):
|
|
||||||
return asyncio.run(generate_embeddings(*args, **kwargs))
|
|
||||||
|
|
||||||
# Semantic expectation from the original version (using sync wrapper).
|
|
||||||
func = lambda query: sync_generate_embeddings(
|
|
||||||
engine=embedding_engine,
|
engine=embedding_engine,
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
text=query,
|
text=query,
|
||||||
key=openai_key if embedding_engine == "openai" else "",
|
url=url,
|
||||||
url=openai_url if embedding_engine == "openai" else "",
|
key=key,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_multiple(query, func):
|
def generate_multiple(query, func):
|
||||||
@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
async def generate_openai_batch_embeddings(
|
def generate_openai_batch_embeddings(
|
||||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
|
||||||
) -> Optional[list[list[float]]]:
|
) -> Optional[list[list[float]]]:
|
||||||
try:
|
try:
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def generate_embeddings(
|
def generate_ollama_batch_embeddings(
|
||||||
engine: str, model: str, text: Union[str, list[str]], **kwargs
|
model: str, texts: list[str], url: str, key: str
|
||||||
):
|
) -> Optional[list[list[float]]]:
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/api/embed",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
},
|
||||||
|
json={"input": texts, "model": model},
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
|
||||||
|
print(data)
|
||||||
|
if "embeddings" in data:
|
||||||
|
return data["embeddings"]
|
||||||
|
else:
|
||||||
|
raise "Something went wrong :/"
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||||
|
url = kwargs.get("url", "")
|
||||||
|
key = kwargs.get("key", "")
|
||||||
|
|
||||||
if engine == "ollama":
|
if engine == "ollama":
|
||||||
if isinstance(text, list):
|
if isinstance(text, list):
|
||||||
embeddings = await generate_ollama_batch_embeddings(
|
embeddings = generate_ollama_batch_embeddings(
|
||||||
GenerateEmbedForm(**{"model": model, "input": text})
|
**{"model": model, "texts": text, "url": url, "key": key}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embeddings = await generate_ollama_batch_embeddings(
|
embeddings = generate_ollama_batch_embeddings(
|
||||||
GenerateEmbedForm(**{"model": model, "input": [text]})
|
**{"model": model, "texts": [text], "url": url, "key": key}
|
||||||
)
|
)
|
||||||
return (
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
embeddings["embeddings"][0]
|
|
||||||
if isinstance(text, str)
|
|
||||||
else embeddings["embeddings"]
|
|
||||||
)
|
|
||||||
elif engine == "openai":
|
elif engine == "openai":
|
||||||
key = kwargs.get("key", "")
|
|
||||||
url = kwargs.get("url", "https://api.openai.com/v1")
|
|
||||||
|
|
||||||
if isinstance(text, list):
|
if isinstance(text, list):
|
||||||
embeddings = await generate_openai_batch_embeddings(model, text, key, url)
|
embeddings = generate_openai_batch_embeddings(model, text, url, key)
|
||||||
else:
|
else:
|
||||||
embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
|
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
|
||||||
|
|
||||||
return embeddings[0] if isinstance(text, str) else embeddings
|
return embeddings[0] if isinstance(text, str) else embeddings
|
||||||
|
|
||||||
|
@ -56,8 +56,11 @@
|
|||||||
let chunkOverlap = 0;
|
let chunkOverlap = 0;
|
||||||
let pdfExtractImages = true;
|
let pdfExtractImages = true;
|
||||||
|
|
||||||
let OpenAIKey = '';
|
|
||||||
let OpenAIUrl = '';
|
let OpenAIUrl = '';
|
||||||
|
let OpenAIKey = '';
|
||||||
|
|
||||||
|
let OllamaUrl = '';
|
||||||
|
let OllamaKey = '';
|
||||||
|
|
||||||
let querySettings = {
|
let querySettings = {
|
||||||
template: '',
|
template: '',
|
||||||
@ -104,19 +107,15 @@
|
|||||||
const res = await updateEmbeddingConfig(localStorage.token, {
|
const res = await updateEmbeddingConfig(localStorage.token, {
|
||||||
embedding_engine: embeddingEngine,
|
embedding_engine: embeddingEngine,
|
||||||
embedding_model: embeddingModel,
|
embedding_model: embeddingModel,
|
||||||
...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
|
embedding_batch_size: embeddingBatchSize,
|
||||||
? {
|
ollama_config: {
|
||||||
embedding_batch_size: embeddingBatchSize
|
key: OllamaKey,
|
||||||
}
|
url: OllamaUrl
|
||||||
: {}),
|
},
|
||||||
...(embeddingEngine === 'openai'
|
openai_config: {
|
||||||
? {
|
key: OpenAIKey,
|
||||||
openai_config: {
|
url: OpenAIUrl
|
||||||
key: OpenAIKey,
|
}
|
||||||
url: OpenAIUrl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
: {})
|
|
||||||
}).catch(async (error) => {
|
}).catch(async (error) => {
|
||||||
toast.error(error);
|
toast.error(error);
|
||||||
await setEmbeddingConfig();
|
await setEmbeddingConfig();
|
||||||
@ -206,6 +205,9 @@
|
|||||||
|
|
||||||
OpenAIKey = embeddingConfig.openai_config.key;
|
OpenAIKey = embeddingConfig.openai_config.key;
|
||||||
OpenAIUrl = embeddingConfig.openai_config.url;
|
OpenAIUrl = embeddingConfig.openai_config.url;
|
||||||
|
|
||||||
|
OllamaKey = embeddingConfig.ollama_config.key;
|
||||||
|
OllamaUrl = embeddingConfig.ollama_config.url;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -310,7 +312,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if embeddingEngine === 'openai'}
|
{#if embeddingEngine === 'openai'}
|
||||||
<div class="my-0.5 flex gap-2">
|
<div class="my-0.5 flex gap-2 pr-2">
|
||||||
<input
|
<input
|
||||||
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
|
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
|
||||||
placeholder={$i18n.t('API Base URL')}
|
placeholder={$i18n.t('API Base URL')}
|
||||||
@ -320,7 +322,23 @@
|
|||||||
|
|
||||||
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
|
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
|
||||||
</div>
|
</div>
|
||||||
|
{:else if embeddingEngine === 'ollama'}
|
||||||
|
<div class="my-0.5 flex gap-2 pr-2">
|
||||||
|
<input
|
||||||
|
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
|
||||||
|
placeholder={$i18n.t('API Base URL')}
|
||||||
|
bind:value={OllamaUrl}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SensitiveInput
|
||||||
|
placeholder={$i18n.t('API Key')}
|
||||||
|
bind:value={OllamaKey}
|
||||||
|
required={false}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
|
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
|
||||||
<div class="flex mt-0.5 space-x-2">
|
<div class="flex mt-0.5 space-x-2">
|
||||||
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
|
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
|
||||||
|
Loading…
Reference in New Issue
Block a user