Mirror of https://github.com/open-webui/open-webui (synced 2025-06-26 02:07:15 +00:00)
Feature: adjusted to handle individual RAG config
This commit is contained in:
Parent: 4189459ae2
Commit: b7023eb564
@ -549,7 +549,7 @@ async def chat_image_generation_handler(
|
|||||||
|
|
||||||
|
|
||||||
async def chat_completion_files_handler(
|
async def chat_completion_files_handler(
|
||||||
request: Request, body: dict, user: UserModel
|
request: Request, body: dict, user: UserModel, model_knowledge
|
||||||
) -> tuple[dict, dict[str, list]]:
|
) -> tuple[dict, dict[str, list]]:
|
||||||
sources = []
|
sources = []
|
||||||
|
|
||||||
@ -587,6 +587,20 @@ async def chat_completion_files_handler(
|
|||||||
queries = [get_last_user_message(body["messages"])]
|
queries = [get_last_user_message(body["messages"])]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# check if individual rag config is used
|
||||||
|
rag_config = {}
|
||||||
|
if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||||
|
rag_config = model_knowledge[0].get("rag_config")
|
||||||
|
|
||||||
|
k=rag_config.get("TOP_K", request.app.state.config.TOP_K)
|
||||||
|
reranking_model = rag_config.get("RAG_RERANKING_MODEL", request.app.state.config.RAG_RERANKING_MODEL)
|
||||||
|
reranking_function=request.app.state.rf[reranking_model] if reranking_model else None
|
||||||
|
k_reranker=rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER)
|
||||||
|
r=rag_config.get("RELEVANCE THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD)
|
||||||
|
hybrid_search=rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH)
|
||||||
|
full_context=rag_config.get("RAG_FULL_CONTEXT", request.app.state.config.RAG_FULL_CONTEXT)
|
||||||
|
embedding_model = rag_config.get("RAG_EMBEDDING_MODEL", request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||||
|
|
||||||
# Offload get_sources_from_files to a separate thread
|
# Offload get_sources_from_files to a separate thread
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
@ -596,15 +610,15 @@ async def chat_completion_files_handler(
|
|||||||
request=request,
|
request=request,
|
||||||
files=files,
|
files=files,
|
||||||
queries=queries,
|
queries=queries,
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[embedding_model](
|
||||||
query, prefix=prefix, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=request.app.state.config.TOP_K,
|
k=k,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=reranking_function,
|
||||||
k_reranker=request.app.state.config.TOP_K_RERANKER,
|
k_reranker=k_reranker,
|
||||||
r=request.app.state.config.RELEVANCE_THRESHOLD,
|
r=r,
|
||||||
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
hybrid_search=hybrid_search,
|
||||||
full_context=request.app.state.config.RAG_FULL_CONTEXT,
|
full_context=full_context,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -862,7 +876,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
form_data, flags = await chat_completion_files_handler(request, form_data, user)
|
form_data, flags = await chat_completion_files_handler(request, form_data, user, model_knowledge)
|
||||||
sources.extend(flags.get("sources", []))
|
sources.extend(flags.get("sources", []))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
@ -898,20 +912,24 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
|
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Adjusted RAG template step to use knowledge-base-specific configuration
|
||||||
|
rag_template_config = request.app.state.config.RAG_TEMPLATE
|
||||||
|
|
||||||
|
if model_knowledge and not model_knowledge[0].get("rag_config").get("DEFAULT_RAG_SETTINGS", True):
|
||||||
|
rag_template_config = model_knowledge[0].get("rag_config").get(
|
||||||
|
"RAG_TEMPLATE", request.app.state.config.RAG_TEMPLATE
|
||||||
|
)
|
||||||
|
|
||||||
# Workaround for Ollama 2.0+ system prompt issue
|
# Workaround for Ollama 2.0+ system prompt issue
|
||||||
# TODO: replace with add_or_update_system_message
|
# TODO: replace with add_or_update_system_message
|
||||||
if model.get("owned_by") == "ollama":
|
if model.get("owned_by") == "ollama":
|
||||||
form_data["messages"] = prepend_to_first_user_message_content(
|
form_data["messages"] = prepend_to_first_user_message_content(
|
||||||
rag_template(
|
rag_template(rag_template_config, context_string, prompt),
|
||||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
form_data["messages"],
|
form_data["messages"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
form_data["messages"] = add_or_update_system_message(
|
form_data["messages"] = add_or_update_system_message(
|
||||||
rag_template(
|
rag_template(rag_template_config, context_string, prompt),
|
||||||
request.app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
form_data["messages"],
|
form_data["messages"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user