mirror of
https://github.com/open-webui/open-webui
synced 2024-11-22 08:07:55 +00:00
refac: ollama setting for rag
This commit is contained in:
parent
e3485d2d88
commit
20321e5271
@ -75,6 +75,8 @@ from open_webui.config import (
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
RAG_OPENAI_API_KEY,
|
||||
RAG_OLLAMA_BASE_URL,
|
||||
RAG_OLLAMA_API_KEY,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_RERANKING_MODEL,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
||||
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
||||
|
||||
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
||||
|
||||
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
||||
@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
(
|
||||
app.state.config.OPENAI_API_BASE_URL
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_BASE_URL
|
||||
),
|
||||
(
|
||||
app.state.config.OPENAI_API_KEY
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_API_KEY
|
||||
),
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
},
|
||||
"ollama_config": {
|
||||
"url": app.state.config.OLLAMA_BASE_URL,
|
||||
"key": app.state.config.OLLAMA_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
|
||||
key: str
|
||||
|
||||
|
||||
class OllamaConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
ollama_config: Optional[OllamaConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
@ -350,6 +373,11 @@ async def update_embedding_config(
|
||||
if form_data.openai_config is not None:
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
|
||||
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
|
||||
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
||||
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||
@ -358,8 +386,16 @@ async def update_embedding_config(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
(
|
||||
app.state.config.OPENAI_API_BASE_URL
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_BASE_URL
|
||||
),
|
||||
(
|
||||
app.state.config.OPENAI_API_KEY
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_API_KEY
|
||||
),
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
@ -372,6 +408,10 @@ async def update_embedding_config(
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
},
|
||||
"ollama_config": {
|
||||
"url": app.state.config.OLLAMA_BASE_URL,
|
||||
"key": app.state.config.OLLAMA_API_KEY,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
@ -785,8 +825,16 @@ def save_docs_to_vector_db(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
(
|
||||
app.state.config.OPENAI_API_BASE_URL
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_BASE_URL
|
||||
),
|
||||
(
|
||||
app.state.config.OPENAI_API_KEY
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else app.state.config.OLLAMA_API_KEY
|
||||
),
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
)
|
||||
|
||||
|
@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbedForm,
|
||||
generate_ollama_batch_embeddings,
|
||||
)
|
||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
|
||||
@ -285,25 +280,19 @@ def get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
url,
|
||||
key,
|
||||
embedding_batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
|
||||
# Wrapper to run the async generate_embeddings synchronously.
|
||||
def sync_generate_embeddings(*args, **kwargs):
|
||||
return asyncio.run(generate_embeddings(*args, **kwargs))
|
||||
|
||||
# Semantic expectation from the original version (using sync wrapper).
|
||||
func = lambda query: sync_generate_embeddings(
|
||||
func = lambda query: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key if embedding_engine == "openai" else "",
|
||||
url=openai_url if embedding_engine == "openai" else "",
|
||||
url=url,
|
||||
key=key,
|
||||
)
|
||||
|
||||
def generate_multiple(query, func):
|
||||
@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
return model
|
||||
|
||||
|
||||
async def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
async def generate_embeddings(
|
||||
engine: str, model: str, text: Union[str, list[str]], **kwargs
|
||||
):
|
||||
def generate_ollama_batch_embeddings(
|
||||
model: str, texts: list[str], url: str, key: str
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/api/embed",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
print(data)
|
||||
if "embeddings" in data:
|
||||
return data["embeddings"]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
url = kwargs.get("url", "")
|
||||
key = kwargs.get("key", "")
|
||||
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = await generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": text})
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{"model": model, "texts": text, "url": url, "key": key}
|
||||
)
|
||||
else:
|
||||
embeddings = await generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": [text]})
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{"model": model, "texts": [text], "url": url, "key": key}
|
||||
)
|
||||
return (
|
||||
embeddings["embeddings"][0]
|
||||
if isinstance(text, str)
|
||||
else embeddings["embeddings"]
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "openai":
|
||||
key = kwargs.get("key", "")
|
||||
url = kwargs.get("url", "https://api.openai.com/v1")
|
||||
|
||||
if isinstance(text, list):
|
||||
embeddings = await generate_openai_batch_embeddings(model, text, key, url)
|
||||
embeddings = generate_openai_batch_embeddings(model, text, url, key)
|
||||
else:
|
||||
embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
@ -56,8 +56,11 @@
|
||||
let chunkOverlap = 0;
|
||||
let pdfExtractImages = true;
|
||||
|
||||
let OpenAIKey = '';
|
||||
let OpenAIUrl = '';
|
||||
let OpenAIKey = '';
|
||||
|
||||
let OllamaUrl = '';
|
||||
let OllamaKey = '';
|
||||
|
||||
let querySettings = {
|
||||
template: '',
|
||||
@ -104,19 +107,15 @@
|
||||
const res = await updateEmbeddingConfig(localStorage.token, {
|
||||
embedding_engine: embeddingEngine,
|
||||
embedding_model: embeddingModel,
|
||||
...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
|
||||
? {
|
||||
embedding_batch_size: embeddingBatchSize
|
||||
}
|
||||
: {}),
|
||||
...(embeddingEngine === 'openai'
|
||||
? {
|
||||
openai_config: {
|
||||
key: OpenAIKey,
|
||||
url: OpenAIUrl
|
||||
}
|
||||
}
|
||||
: {})
|
||||
embedding_batch_size: embeddingBatchSize,
|
||||
ollama_config: {
|
||||
key: OllamaKey,
|
||||
url: OllamaUrl
|
||||
},
|
||||
openai_config: {
|
||||
key: OpenAIKey,
|
||||
url: OpenAIUrl
|
||||
}
|
||||
}).catch(async (error) => {
|
||||
toast.error(error);
|
||||
await setEmbeddingConfig();
|
||||
@ -206,6 +205,9 @@
|
||||
|
||||
OpenAIKey = embeddingConfig.openai_config.key;
|
||||
OpenAIUrl = embeddingConfig.openai_config.url;
|
||||
|
||||
OllamaKey = embeddingConfig.ollama_config.key;
|
||||
OllamaUrl = embeddingConfig.ollama_config.url;
|
||||
}
|
||||
};
|
||||
|
||||
@ -310,7 +312,7 @@
|
||||
</div>
|
||||
|
||||
{#if embeddingEngine === 'openai'}
|
||||
<div class="my-0.5 flex gap-2">
|
||||
<div class="my-0.5 flex gap-2 pr-2">
|
||||
<input
|
||||
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
|
||||
placeholder={$i18n.t('API Base URL')}
|
||||
@ -320,7 +322,23 @@
|
||||
|
||||
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
|
||||
</div>
|
||||
{:else if embeddingEngine === 'ollama'}
|
||||
<div class="my-0.5 flex gap-2 pr-2">
|
||||
<input
|
||||
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
|
||||
placeholder={$i18n.t('API Base URL')}
|
||||
bind:value={OllamaUrl}
|
||||
required
|
||||
/>
|
||||
|
||||
<SensitiveInput
|
||||
placeholder={$i18n.t('API Key')}
|
||||
bind:value={OllamaKey}
|
||||
required={false}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
|
||||
<div class="flex mt-0.5 space-x-2">
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
|
||||
|
Loading…
Reference in New Issue
Block a user