refac: "rag" endpoints renamed to "retrieval"

Timothy J. Baek 2024-09-28 01:27:46 +02:00
parent 6e9db3e3c8
commit e1103305f5
27 changed files with 41 additions and 34 deletions
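
For API consumers, the externally visible change is the mount prefix: routes that used to be served under /rag/api/v1 are now served under /retrieval/api/v1 (the frontend's RAG_API_BASE_URL constant is updated to match in the last hunk below). A minimal client-side sketch of the adjustment, assuming a local instance at http://localhost:8080; the /process/doc path and the request body are illustrative placeholders, not taken from this diff:

import requests

WEBUI_BASE_URL = "http://localhost:8080"  # assumed local instance

OLD_RETRIEVAL_API_BASE_URL = f"{WEBUI_BASE_URL}/rag/api/v1"        # before this commit
NEW_RETRIEVAL_API_BASE_URL = f"{WEBUI_BASE_URL}/retrieval/api/v1"  # after this commit


def process_doc(token: str, file_id: str) -> dict:
    # "/process/doc" and the JSON body are hypothetical; only the change of base
    # prefix from /rag/api/v1 to /retrieval/api/v1 comes from this commit.
    response = requests.post(
        f"{NEW_RETRIEVAL_API_BASE_URL}/process/doc",
        headers={"Authorization": f"Bearer {token}"},
        json={"file_id": file_id},
    )
    response.raise_for_status()
    return response.json()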

View File

@@ -1061,7 +1061,7 @@ def store_data_in_vector_db(
     if len(docs) > 0:
         log.info(f"store_data_in_vector_db {docs}")
-        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None
+        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite)
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
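
The hunk above also changes the return contract: store_data_in_vector_db no longer hands back a (result, error) pair but the bare result of store_docs_in_vector_db. A standalone sketch of the call-site impact, using stub functions rather than the real ones from the retrieval app (argument names are illustrative):

# Stub stand-in for the real helper; it only illustrates the shape of the change.
def store_docs_in_vector_db(docs, collection_name, metadata=None, overwrite=False) -> bool:
    return bool(docs)  # pretend the write succeeded when there are docs


# Old shape: a (result, error) tuple was returned.
def store_data_in_vector_db_old(docs, collection_name):
    return store_docs_in_vector_db(docs, collection_name), None


# New shape: the bare result is returned.
def store_data_in_vector_db_new(docs, collection_name):
    return store_docs_in_vector_db(docs, collection_name)


# Call sites therefore move from tuple unpacking ...
stored, _ = store_data_in_vector_db_old(["chunk"], "test-collection")
# ... to using the return value directly.
stored = store_data_in_vector_db_new(["chunk"], "test-collection")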
@@ -1377,6 +1377,7 @@ def process_doc(
         )
         if result:
             return {
                 "status": True,
                 "collection_name": collection_name,

View File

@@ -16,37 +16,45 @@ from typing import Optional
 import aiohttp
 import requests
-from open_webui.apps.audio.main import app as audio_app
-from open_webui.apps.images.main import app as images_app
-from open_webui.apps.ollama.main import app as ollama_app
 from open_webui.apps.ollama.main import (
-    GenerateChatCompletionForm,
+    app as ollama_app,
+    get_all_models as get_ollama_models,
     generate_chat_completion as generate_ollama_chat_completion,
     generate_openai_chat_completion as generate_ollama_openai_chat_completion,
+    GenerateChatCompletionForm,
 )
-from open_webui.apps.ollama.main import get_all_models as get_ollama_models
-from open_webui.apps.openai.main import app as openai_app
 from open_webui.apps.openai.main import (
+    app as openai_app,
     generate_chat_completion as generate_openai_chat_completion,
+    get_all_models as get_openai_models,
 )
-from open_webui.apps.openai.main import get_all_models as get_openai_models
-from open_webui.apps.rag.main import app as rag_app
-from open_webui.apps.rag.utils import get_rag_context, rag_template
-from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
-from open_webui.apps.socket.main import get_event_call, get_event_emitter
-from open_webui.apps.webui.internal.db import Session
-from open_webui.apps.webui.main import app as webui_app
+from open_webui.apps.retrieval.main import app as retrieval_app
+from open_webui.apps.retrieval.utils import get_rag_context, rag_template
+from open_webui.apps.socket.main import (
+    app as socket_app,
+    periodic_usage_pool_cleanup,
+    get_event_call,
+    get_event_emitter,
+)
 from open_webui.apps.webui.main import (
+    app as webui_app,
     generate_function_chat_completion,
     get_pipe_models,
 )
+from open_webui.apps.webui.internal.db import Session
 from open_webui.apps.webui.models.auths import Auths
 from open_webui.apps.webui.models.functions import Functions
 from open_webui.apps.webui.models.models import Models
 from open_webui.apps.webui.models.users import UserModel, Users
 from open_webui.apps.webui.utils import load_function_module_by_id
+from open_webui.apps.audio.main import app as audio_app
+from open_webui.apps.images.main import app as images_app
 from authlib.integrations.starlette_client import OAuth
 from authlib.oidc.core import UserInfo
@@ -491,11 +499,11 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
-            embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-            k=rag_app.state.config.TOP_K,
-            reranking_function=rag_app.state.sentence_transformer_rf,
-            r=rag_app.state.config.RELEVANCE_THRESHOLD,
-            hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+            embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
+            k=retrieval_app.state.config.TOP_K,
+            reranking_function=retrieval_app.state.sentence_transformer_rf,
+            r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         )
         log.debug(f"rag_contexts: {contexts}, citations: {citations}")
@@ -608,7 +616,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if prompt is None:
             raise Exception("No user message found")
         if (
-            rag_app.state.config.RELEVANCE_THRESHOLD == 0
+            retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
             and context_string.strip() == ""
         ):
             log.debug(
@@ -620,14 +628,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if model["owned_by"] == "ollama":
                 body["messages"] = prepend_to_first_user_message_content(
                     rag_template(
-                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
                     ),
                     body["messages"],
                 )
             else:
                 body["messages"] = add_or_update_system_message(
                     rag_template(
-                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                        retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
                     ),
                     body["messages"],
                 )
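
The branch above injects the retrieved context differently per backend: for Ollama models the rendered RAG template is prepended to the first user message, otherwise it is placed in or merged into the system message. A rough, self-contained sketch of that idea, with simplified stand-ins for the two message helpers (the real implementations live in Open WebUI's utils and may differ):

def prepend_to_first_user_message_content(content: str, messages: list[dict]) -> list[dict]:
    # Simplified stand-in: prefix the first user message with the rendered template.
    for message in messages:
        if message.get("role") == "user":
            message["content"] = f"{content}\n{message['content']}"
            break
    return messages


def add_or_update_system_message(content: str, messages: list[dict]) -> list[dict]:
    # Simplified stand-in: reuse an existing system message or insert one at the front.
    for message in messages:
        if message.get("role") == "system":
            message["content"] = f"{content}\n{message['content']}"
            return messages
    return [{"role": "system", "content": content}] + messages


messages = [{"role": "user", "content": "What does the report say?"}]
context_block = "Use the following context: ..."  # stands in for rag_template(...) output

owned_by = "ollama"
if owned_by == "ollama":
    messages = prepend_to_first_user_message_content(context_block, messages)
else:
    messages = add_or_update_system_message(context_block, messages)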
@@ -849,7 +857,7 @@ async def check_url(request: Request, call_next):
 async def update_embedding_function(request: Request, call_next):
     response = await call_next(request)
     if "/embedding/update" in request.url.path:
-        webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
+        webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
     return response
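
For context, middleware like update_embedding_function above is typically registered on the FastAPI app with the @app.middleware("http") decorator; the registration line is not part of this hunk, so the sketch below is an approximation rather than the repository's exact code:

from fastapi import FastAPI, Request

app = FastAPI()
webui_app = FastAPI()      # stand-ins for the real sub-applications
retrieval_app = FastAPI()
retrieval_app.state.EMBEDDING_FUNCTION = lambda text: [0.0]  # placeholder embedder


@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
    response = await call_next(request)
    # After an embedding-settings update, mirror the retrieval app's embedding
    # function onto the webui app so both sides embed with the same model.
    if "/embedding/update" in request.url.path:
        webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
    return response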
@@ -877,11 +885,12 @@ app.mount("/openai", openai_app)
 app.mount("/images/api/v1", images_app)
 app.mount("/audio/api/v1", audio_app)
-app.mount("/rag/api/v1", rag_app)
+app.mount("/retrieval/api/v1", retrieval_app)
 app.mount("/api/v1", webui_app)
-webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
+webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
 async def get_all_models():
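
The mount change above is what actually moves the public endpoints: a mounted sub-application keeps its own routes, and only the prefix in front of them changes. A minimal sketch of the behavior, with an illustrative /config route standing in for the retrieval app's real routes:

from fastapi import FastAPI

app = FastAPI()
retrieval_app = FastAPI()


@retrieval_app.get("/config")  # illustrative route, not taken from this diff
async def get_config():
    return {"status": True}


# After this commit the sub-app lives under /retrieval/api/v1 instead of /rag/api/v1,
# so the route above is reachable at /retrieval/api/v1/config and the old prefix 404s.
app.mount("/retrieval/api/v1", retrieval_app)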
@@ -2066,7 +2075,7 @@ async def get_app_config(request: Request):
             "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
             **(
                 {
-                    "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
+                    "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
                     "enable_image_generation": images_app.state.config.ENABLED,
                     "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
                     "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
@@ -2092,8 +2101,8 @@ async def get_app_config(request: Request):
             },
         },
         "file": {
-            "max_size": rag_app.state.config.FILE_MAX_SIZE,
-            "max_count": rag_app.state.config.FILE_MAX_COUNT,
+            "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
+            "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
         },
         "permissions": {**webui_app.state.config.USER_PERMISSIONS},
     }

View File

@@ -159,16 +159,13 @@
     const processFileItem = async (fileItem) => {
         try {
             const res = await processDocToVectorDB(localStorage.token, fileItem.id);
             if (res) {
                 fileItem.status = 'processed';
                 fileItem.collection_name = res.collection_name;
                 files = files;
             }
         } catch (e) {
-            // Remove the failed doc from the files array
-            // files = files.filter((f) => f.id !== fileItem.id);
-            toast.error(e);
+            // We keep the file in the files list even if it fails to process
             fileItem.status = 'processed';
             files = files;
         }

View File

@@ -11,7 +11,7 @@ export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama`;
 export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai`;
 export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;
 export const IMAGES_API_BASE_URL = `${WEBUI_BASE_URL}/images/api/v1`;
-export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
+export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/retrieval/api/v1`;
 export const WEBUI_VERSION = APP_VERSION;
 export const WEBUI_BUILD_HASH = APP_BUILD_HASH;