mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
refac: tools & rag
This commit is contained in:
parent
fc46532955
commit
bd5a8567ef
@ -236,10 +236,9 @@ def get_embedding_function(
|
|||||||
return lambda query: generate_multiple(query, func)
|
return lambda query: generate_multiple(query, func)
|
||||||
|
|
||||||
|
|
||||||
def rag_messages(
|
def get_rag_context(
|
||||||
docs,
|
docs,
|
||||||
messages,
|
messages,
|
||||||
template,
|
|
||||||
embedding_function,
|
embedding_function,
|
||||||
k,
|
k,
|
||||||
reranking_function,
|
reranking_function,
|
||||||
@ -318,16 +317,7 @@ def rag_messages(
|
|||||||
|
|
||||||
context_string = context_string.strip()
|
context_string = context_string.strip()
|
||||||
|
|
||||||
ra_content = rag_template(
|
return context_string, citations
|
||||||
template=template,
|
|
||||||
context=context_string,
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.debug(f"ra_content: {ra_content}")
|
|
||||||
messages = add_or_update_system_message(ra_content, messages)
|
|
||||||
|
|
||||||
return messages, citations
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model: str, update_model: bool = False):
|
def get_model_path(model: str, update_model: bool = False):
|
||||||
|
@ -64,7 +64,7 @@ from utils.task import (
|
|||||||
)
|
)
|
||||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||||
|
|
||||||
from apps.rag.utils import rag_messages, rag_template
|
from apps.rag.utils import get_rag_context, rag_template
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
CONFIG_DATA,
|
CONFIG_DATA,
|
||||||
@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
# Parse string to JSON
|
# Parse string to JSON
|
||||||
data = json.loads(body_str) if body_str else {}
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
user = get_current_user(
|
||||||
|
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||||
|
)
|
||||||
|
|
||||||
# Remove the citations from the body
|
# Remove the citations from the body
|
||||||
return_citations = data.get("citations", False)
|
return_citations = data.get("citations", False)
|
||||||
if "citations" in data:
|
if "citations" in data:
|
||||||
@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
):
|
):
|
||||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||||
|
|
||||||
if "tool_ids" in data:
|
|
||||||
user = get_current_user(
|
|
||||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
|
||||||
)
|
|
||||||
prompt = get_last_user_message(data["messages"])
|
|
||||||
context = ""
|
context = ""
|
||||||
|
|
||||||
|
# If tool_ids field is present, call the functions
|
||||||
|
if "tool_ids" in data:
|
||||||
|
prompt = get_last_user_message(data["messages"])
|
||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
print(tool_id)
|
print(tool_id)
|
||||||
response = await get_function_call_response(
|
response = await get_function_call_response(
|
||||||
@ -295,6 +297,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
if response:
|
if response:
|
||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
|
del data["tool_ids"]
|
||||||
|
|
||||||
|
# If docs field is present, generate RAG completions
|
||||||
|
if "docs" in data:
|
||||||
|
data = {**data}
|
||||||
|
rag_context, citations = get_rag_context(
|
||||||
|
docs=data["docs"],
|
||||||
|
messages=data["messages"],
|
||||||
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||||
|
k=rag_app.state.config.TOP_K,
|
||||||
|
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||||
|
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||||
|
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rag_context:
|
||||||
|
context += ("\n" if context != "" else "") + rag_context
|
||||||
|
|
||||||
|
del data["docs"]
|
||||||
|
|
||||||
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||||
|
|
||||||
if context != "":
|
if context != "":
|
||||||
system_prompt = rag_template(
|
system_prompt = rag_template(
|
||||||
@ -307,27 +330,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
f"\n{system_prompt}", data["messages"]
|
f"\n{system_prompt}", data["messages"]
|
||||||
)
|
)
|
||||||
|
|
||||||
del data["tool_ids"]
|
|
||||||
|
|
||||||
# If docs field is present, generate RAG completions
|
|
||||||
if "docs" in data:
|
|
||||||
data = {**data}
|
|
||||||
data["messages"], citations = rag_messages(
|
|
||||||
docs=data["docs"],
|
|
||||||
messages=data["messages"],
|
|
||||||
template=rag_app.state.config.RAG_TEMPLATE,
|
|
||||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
||||||
k=rag_app.state.config.TOP_K,
|
|
||||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
|
||||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
||||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
||||||
)
|
|
||||||
del data["docs"]
|
|
||||||
|
|
||||||
log.debug(
|
|
||||||
f"data['messages']: {data['messages']}, citations: {citations}"
|
|
||||||
)
|
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
|
|
||||||
# Replace the request body with the modified one
|
# Replace the request body with the modified one
|
||||||
|
Loading…
Reference in New Issue
Block a user