refac: ollama setting for rag

This commit is contained in:
Timothy Jaeryang Baek 2024-11-18 14:19:56 -08:00
parent e3485d2d88
commit 20321e5271
3 changed files with 130 additions and 56 deletions

View File

@ -75,6 +75,8 @@ from open_webui.config import (
RAG_FILE_MAX_SIZE, RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY, RAG_OPENAI_API_KEY,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
RAG_RELEVANCE_THRESHOLD, RAG_RELEVANCE_THRESHOLD,
RAG_RERANKING_MODEL, RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY, (
app.state.config.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE, app.state.config.RAG_EMBEDDING_BATCH_SIZE,
) )
@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"url": app.state.config.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
} }
@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
key: str key: str
class OllamaConfigForm(BaseModel):
    """Ollama connection settings submitted with an embedding-config update.

    Carried as the optional ``ollama_config`` field of
    ``EmbeddingModelUpdateForm``; when present, ``update_embedding_config``
    copies ``url`` into ``app.state.config.OLLAMA_BASE_URL`` and ``key`` into
    ``app.state.config.OLLAMA_API_KEY``.
    """

    # Base URL of the Ollama server used for RAG embeddings.
    url: str
    # API key stored alongside the URL; presumably empty when the server
    # needs no authentication -- TODO confirm against the settings UI.
    key: str
class EmbeddingModelUpdateForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
embedding_engine: str embedding_engine: str
embedding_model: str embedding_model: str
embedding_batch_size: Optional[int] = 1 embedding_batch_size: Optional[int] = 1
@ -350,6 +373,11 @@ async def update_embedding_config(
if form_data.openai_config is not None: if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key app.state.config.OPENAI_API_KEY = form_data.openai_config.key
if form_data.ollama_config is not None:
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@ -358,8 +386,16 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY, (
app.state.config.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE, app.state.config.RAG_EMBEDDING_BATCH_SIZE,
) )
@ -372,6 +408,10 @@ async def update_embedding_config(
"url": app.state.config.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
"ollama_config": {
"url": app.state.config.OLLAMA_BASE_URL,
"key": app.state.config.OLLAMA_API_KEY,
},
} }
except Exception as e: except Exception as e:
log.exception(f"Problem updating embedding model: {e}") log.exception(f"Problem updating embedding model: {e}")
@ -785,8 +825,16 @@ def save_docs_to_vector_db(
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY, (
app.state.config.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE, app.state.config.RAG_EMBEDDING_BATCH_SIZE,
) )

View File

@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbedForm,
generate_ollama_batch_embeddings,
)
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message from open_webui.utils.misc import get_last_user_message
@ -285,25 +280,19 @@ def get_embedding_function(
embedding_engine, embedding_engine,
embedding_model, embedding_model,
embedding_function, embedding_function,
openai_key, url,
openai_url, key,
embedding_batch_size, embedding_batch_size,
): ):
if embedding_engine == "": if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist() return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]: elif embedding_engine in ["ollama", "openai"]:
func = lambda query: generate_embeddings(
# Wrapper to run the async generate_embeddings synchronously.
def sync_generate_embeddings(*args, **kwargs):
return asyncio.run(generate_embeddings(*args, **kwargs))
# Semantic expectation from the original version (using sync wrapper).
func = lambda query: sync_generate_embeddings(
engine=embedding_engine, engine=embedding_engine,
model=embedding_model, model=embedding_model,
text=query, text=query,
key=openai_key if embedding_engine == "openai" else "", url=url,
url=openai_url if embedding_engine == "openai" else "", key=key,
) )
def generate_multiple(query, func): def generate_multiple(query, func):
@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False):
return model return model
async def generate_openai_batch_embeddings( def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1" model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
) -> Optional[list[list[float]]]: ) -> Optional[list[list[float]]]:
try: try:
r = requests.post( r = requests.post(
@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings(
return None return None
def generate_ollama_batch_embeddings(
    model: str, texts: list[str], url: str, key: str
) -> Optional[list[list[float]]]:
    """Request embeddings for a batch of texts from an Ollama server.

    Sends ``{"input": texts, "model": model}`` as JSON in a POST request to
    ``{url}/api/embed`` with ``key`` as a bearer token.

    Args:
        model: Name of the embedding model to use.
        texts: The strings to embed, sent in a single request.
        url: Base URL of the Ollama server (``/api/embed`` is appended).
        key: API key placed in the ``Authorization`` header.

    Returns:
        The value of the ``"embeddings"`` entry of the JSON response, or
        ``None`` when the request fails, the server answers with an error
        status, or the response has no ``"embeddings"`` entry. The failure
        reason is printed; no exception reaches the caller.
    """
    try:
        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" not in data:
            # Raise a real exception: raising a bare string is itself a
            # TypeError in Python 3 and would hide this message.
            raise ValueError("Something went wrong :/")
        return data["embeddings"]
    except Exception as e:
        # Best-effort contract: report the problem and signal failure with
        # None instead of propagating.
        print(e)
        return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
    """Embed one string or a batch of strings with the selected engine.

    Args:
        engine: ``"ollama"`` or ``"openai"``; any other value yields ``None``.
        model: Name of the embedding model passed to the engine.
        text: A single string or a list of strings to embed.
        **kwargs: Optional ``url`` (server base URL) and ``key`` (API key).
            ``key`` defaults to ``""``. ``url`` defaults to ``""`` for Ollama
            and to ``https://api.openai.com/v1`` for OpenAI.

    Returns:
        For a ``str`` input, the first embedding of the batch result; for any
        other input, the batch result as returned by the helper. ``None`` when
        the engine is not recognised or the batch helper returned ``None``.
    """
    key = kwargs.get("key", "")

    # The batch helpers take a list, so a lone value is wrapped here and
    # unwrapped again on the way out.
    texts = text if isinstance(text, list) else [text]

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            model=model, texts=texts, url=kwargs.get("url", ""), key=key
        )
    elif engine == "openai":
        # Fall back to the public endpoint only when the caller gave no URL
        # at all; an empty default would override the helper's own default.
        embeddings = generate_openai_batch_embeddings(
            model=model,
            texts=texts,
            url=kwargs.get("url", "https://api.openai.com/v1"),
            key=key,
        )
    else:
        return None

    # A failed batch call yields None; pass that on rather than failing on
    # the index below.
    if embeddings is None:
        return None

    return embeddings[0] if isinstance(text, str) else embeddings

View File

@ -56,8 +56,11 @@
let chunkOverlap = 0; let chunkOverlap = 0;
let pdfExtractImages = true; let pdfExtractImages = true;
let OpenAIKey = '';
let OpenAIUrl = ''; let OpenAIUrl = '';
let OpenAIKey = '';
let OllamaUrl = '';
let OllamaKey = '';
let querySettings = { let querySettings = {
template: '', template: '',
@ -104,19 +107,15 @@
const res = await updateEmbeddingConfig(localStorage.token, { const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine, embedding_engine: embeddingEngine,
embedding_model: embeddingModel, embedding_model: embeddingModel,
...(embeddingEngine === 'openai' || embeddingEngine === 'ollama' embedding_batch_size: embeddingBatchSize,
? { ollama_config: {
embedding_batch_size: embeddingBatchSize key: OllamaKey,
} url: OllamaUrl
: {}), },
...(embeddingEngine === 'openai' openai_config: {
? { key: OpenAIKey,
openai_config: { url: OpenAIUrl
key: OpenAIKey, }
url: OpenAIUrl
}
}
: {})
}).catch(async (error) => { }).catch(async (error) => {
toast.error(error); toast.error(error);
await setEmbeddingConfig(); await setEmbeddingConfig();
@ -206,6 +205,9 @@
OpenAIKey = embeddingConfig.openai_config.key; OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url; OpenAIUrl = embeddingConfig.openai_config.url;
OllamaKey = embeddingConfig.ollama_config.key;
OllamaUrl = embeddingConfig.ollama_config.url;
} }
}; };
@ -310,7 +312,7 @@
</div> </div>
{#if embeddingEngine === 'openai'} {#if embeddingEngine === 'openai'}
<div class="my-0.5 flex gap-2"> <div class="my-0.5 flex gap-2 pr-2">
<input <input
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none" class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
placeholder={$i18n.t('API Base URL')} placeholder={$i18n.t('API Base URL')}
@ -320,7 +322,23 @@
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} /> <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
</div> </div>
{:else if embeddingEngine === 'ollama'}
<div class="my-0.5 flex gap-2 pr-2">
<input
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={OllamaUrl}
required
/>
<SensitiveInput
placeholder={$i18n.t('API Key')}
bind:value={OllamaKey}
required={false}
/>
</div>
{/if} {/if}
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'} {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
<div class="flex mt-0.5 space-x-2"> <div class="flex mt-0.5 space-x-2">
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div> <div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>