mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
refac: tools & rag
This commit is contained in:
parent
fc46532955
commit
bd5a8567ef
@ -236,10 +236,9 @@ def get_embedding_function(
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
|
||||
def rag_messages(
|
||||
def get_rag_context(
|
||||
docs,
|
||||
messages,
|
||||
template,
|
||||
embedding_function,
|
||||
k,
|
||||
reranking_function,
|
||||
@ -318,16 +317,7 @@ def rag_messages(
|
||||
|
||||
context_string = context_string.strip()
|
||||
|
||||
ra_content = rag_template(
|
||||
template=template,
|
||||
context=context_string,
|
||||
query=query,
|
||||
)
|
||||
|
||||
log.debug(f"ra_content: {ra_content}")
|
||||
messages = add_or_update_system_message(ra_content, messages)
|
||||
|
||||
return messages, citations
|
||||
return context_string, citations
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
|
@ -64,7 +64,7 @@ from utils.task import (
|
||||
)
|
||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||
|
||||
from apps.rag.utils import rag_messages, rag_template
|
||||
from apps.rag.utils import get_rag_context, rag_template
|
||||
|
||||
from config import (
|
||||
CONFIG_DATA,
|
||||
@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
)
|
||||
|
||||
# Remove the citations from the body
|
||||
return_citations = data.get("citations", False)
|
||||
if "citations" in data:
|
||||
@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
):
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
|
||||
if "tool_ids" in data:
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
)
|
||||
prompt = get_last_user_message(data["messages"])
|
||||
context = ""
|
||||
context = ""
|
||||
|
||||
# If tool_ids field is present, call the functions
|
||||
if "tool_ids" in data:
|
||||
prompt = get_last_user_message(data["messages"])
|
||||
for tool_id in data["tool_ids"]:
|
||||
print(tool_id)
|
||||
response = await get_function_call_response(
|
||||
@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
if response:
|
||||
context += ("\n" if context != "" else "") + response
|
||||
|
||||
if context != "":
|
||||
system_prompt = rag_template(
|
||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
||||
)
|
||||
|
||||
print(system_prompt)
|
||||
|
||||
data["messages"] = add_or_update_system_message(
|
||||
f"\n{system_prompt}", data["messages"]
|
||||
)
|
||||
|
||||
del data["tool_ids"]
|
||||
|
||||
# If docs field is present, generate RAG completions
|
||||
if "docs" in data:
|
||||
data = {**data}
|
||||
data["messages"], citations = rag_messages(
|
||||
rag_context, citations = get_rag_context(
|
||||
docs=data["docs"],
|
||||
messages=data["messages"],
|
||||
template=rag_app.state.config.RAG_TEMPLATE,
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.config.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
|
||||
if rag_context:
|
||||
context += ("\n" if context != "" else "") + rag_context
|
||||
|
||||
del data["docs"]
|
||||
|
||||
log.debug(
|
||||
f"data['messages']: {data['messages']}, citations: {citations}"
|
||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||
|
||||
if context != "":
|
||||
system_prompt = rag_template(
|
||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
||||
)
|
||||
|
||||
print(system_prompt)
|
||||
|
||||
data["messages"] = add_or_update_system_message(
|
||||
f"\n{system_prompt}", data["messages"]
|
||||
)
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
|
Loading…
Reference in New Issue
Block a user