refac: ollama setting for rag

This commit is contained in:
Timothy Jaeryang Baek 2024-11-18 14:19:56 -08:00
parent e3485d2d88
commit 20321e5271
3 changed files with 130 additions and 56 deletions

View File

@ -75,6 +75,8 @@ from open_webui.config import (
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
RAG_RELEVANCE_THRESHOLD,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
}
@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
key: str
# Request-body model for the Ollama connection settings used by the embedding
# configuration endpoint; the update handler copies `url` into
# app.state.config.OLLAMA_BASE_URL and `key` into app.state.config.OLLAMA_API_KEY.
class OllamaConfigForm(BaseModel):
# Base URL of the Ollama server.
url: str
# API key for that server (presumably may be an empty string when the
# server needs no auth — confirm against the settings UI).
key: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@ -350,6 +373,11 @@ async def update_embedding_config(
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
if form_data.ollama_config is not None:
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@ -358,8 +386,16 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
@ -372,6 +408,10 @@ async def update_embedding_config(
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@ -785,8 +825,16 @@ def save_docs_to_vector_db(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)

View File

@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
@ -285,25 +280,19 @@ def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
url,
key,
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
# Wrapper to run the async generate_embeddings synchronously.
def sync_generate_embeddings(*args, **kwargs):
return asyncio.run(generate_embeddings(*args, **kwargs))
# Semantic expectation from the original version (using sync wrapper).
func = lambda query: sync_generate_embeddings(
func = lambda query: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
key=openai_key if embedding_engine == "openai" else "",
url=openai_url if embedding_engine == "openai" else "",
url=url,
key=key,
)
def generate_multiple(query, func):
@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False):
return model
async def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
def generate_openai_batch_embeddings(
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings(
return None
async def generate_embeddings(
engine: str, model: str, text: Union[str, list[str]], **kwargs
):
def generate_ollama_batch_embeddings(
    model: str, texts: list[str], url: str, key: str
) -> Optional[list[list[float]]]:
    """Request embeddings for a batch of texts from an Ollama server.

    Sends a POST to ``{url}/api/embed`` with the JSON body
    ``{"input": texts, "model": model}`` and returns the ``"embeddings"``
    entry of the JSON response.

    Args:
        model: Name of the embedding model to use.
        texts: Texts to embed, sent together in a single request.
        url: Base URL of the Ollama server; a trailing slash is tolerated.
        key: API key sent as a bearer token. When empty, no ``Authorization``
            header is sent at all.

    Returns:
        The list of embedding vectors from the response, or ``None`` if the
        request fails, the server answers with an error status, or the
        response has no ``"embeddings"`` field.
    """
    # Local import keeps this block self-contained; the failure is reported
    # through the logging system instead of a bare print().
    import logging

    headers = {"Content-Type": "application/json"}
    if key:
        # Avoid emitting a malformed "Bearer " header for unauthenticated servers.
        headers["Authorization"] = f"Bearer {key}"

    try:
        r = requests.post(
            f"{url.rstrip('/')}/api/embed",
            headers=headers,
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" not in data:
            # Must be a real exception instance: raising a plain string is a
            # TypeError in Python 3 and would hide this message.
            raise ValueError("Something went wrong :/")
        return data["embeddings"]
    except Exception as e:
        # Best-effort contract: callers expect None on any failure.
        logging.getLogger(__name__).exception(
            "Ollama batch embedding request failed: %s", e
        )
        return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
if engine == "ollama":
if isinstance(text, list):
embeddings = await generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": text})
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": text, "url": url, "key": key}
)
else:
embeddings = await generate_ollama_batch_embeddings(
GenerateEmbedForm(**{"model": model, "input": [text]})
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": [text], "url": url, "key": key}
)
return (
embeddings["embeddings"][0]
if isinstance(text, str)
else embeddings["embeddings"]
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
key = kwargs.get("key", "")
url = kwargs.get("url", "https://api.openai.com/v1")
if isinstance(text, list):
embeddings = await generate_openai_batch_embeddings(model, text, key, url)
embeddings = generate_openai_batch_embeddings(model, text, url, key)
else:
embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
return embeddings[0] if isinstance(text, str) else embeddings

View File

@ -56,8 +56,11 @@
let chunkOverlap = 0;
let pdfExtractImages = true;
let OpenAIKey = '';
let OpenAIUrl = '';
let OpenAIKey = '';
let OllamaUrl = '';
let OllamaKey = '';
let querySettings = {
template: '',
@ -104,19 +107,15 @@
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_model: embeddingModel,
...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
? {
embedding_batch_size: embeddingBatchSize
}
: {}),
...(embeddingEngine === 'openai'
? {
openai_config: {
key: OpenAIKey,
url: OpenAIUrl
}
}
: {})
embedding_batch_size: embeddingBatchSize,
ollama_config: {
key: OllamaKey,
url: OllamaUrl
},
openai_config: {
key: OpenAIKey,
url: OpenAIUrl
}
}).catch(async (error) => {
toast.error(error);
await setEmbeddingConfig();
@ -206,6 +205,9 @@
OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url;
OllamaKey = embeddingConfig.ollama_config.key;
OllamaUrl = embeddingConfig.ollama_config.url;
}
};
@ -310,7 +312,7 @@
</div>
{#if embeddingEngine === 'openai'}
<div class="my-0.5 flex gap-2">
<div class="my-0.5 flex gap-2 pr-2">
<input
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
placeholder={$i18n.t('API Base URL')}
@ -320,7 +322,23 @@
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
</div>
{:else if embeddingEngine === 'ollama'}
<div class="my-0.5 flex gap-2 pr-2">
<input
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={OllamaUrl}
required
/>
<SensitiveInput
placeholder={$i18n.t('API Key')}
bind:value={OllamaKey}
required={false}
/>
</div>
{/if}
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
<div class="flex mt-0.5 space-x-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>