mirror of
https://github.com/open-webui/open-webui
synced 2025-02-21 04:17:26 +00:00
feat: tool citation
This commit is contained in:
parent
58ae91369e
commit
6bb2f41812
@ -247,6 +247,7 @@ async def get_function_call_response(
|
|||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
|
citation = None
|
||||||
# Call the function
|
# Call the function
|
||||||
if "name" in result:
|
if "name" in result:
|
||||||
if tool_id in webui_app.state.TOOLS:
|
if tool_id in webui_app.state.TOOLS:
|
||||||
@ -309,22 +310,32 @@ async def get_function_call_response(
|
|||||||
}
|
}
|
||||||
|
|
||||||
function_result = function(**params)
|
function_result = function(**params)
|
||||||
|
|
||||||
|
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
||||||
|
citation = {
|
||||||
|
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
|
||||||
|
"document": [function_result],
|
||||||
|
"metadata": [{"source": result["name"]}],
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
# Add the function result to the system prompt
|
# Add the function result to the system prompt
|
||||||
if function_result is not None:
|
if function_result is not None:
|
||||||
return function_result, file_handler
|
return function_result, citation, file_handler
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
return None, False
|
return None, None, False
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
data_items = []
|
data_items = []
|
||||||
|
|
||||||
|
show_citations = False
|
||||||
|
citations = []
|
||||||
|
|
||||||
if request.method == "POST" and any(
|
if request.method == "POST" and any(
|
||||||
endpoint in request.url.path
|
endpoint in request.url.path
|
||||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||||
@ -342,6 +353,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
# Flag to skip RAG completions if file_handler is present in tools/functions
|
# Flag to skip RAG completions if file_handler is present in tools/functions
|
||||||
skip_files = False
|
skip_files = False
|
||||||
|
if data.get("citations"):
|
||||||
|
show_citations = True
|
||||||
|
del data["citations"]
|
||||||
|
|
||||||
model_id = data["model"]
|
model_id = data["model"]
|
||||||
if model_id not in app.state.MODELS:
|
if model_id not in app.state.MODELS:
|
||||||
@ -365,8 +379,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||||
|
|
||||||
# Check if the function has a file_handler variable
|
# Check if the function has a file_handler variable
|
||||||
if getattr(function_module, "file_handler"):
|
if hasattr(function_module, "file_handler"):
|
||||||
skip_files = True
|
skip_files = function_module.file_handler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(function_module, "inlet"):
|
if hasattr(function_module, "inlet"):
|
||||||
@ -411,7 +425,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
print(tool_id)
|
print(tool_id)
|
||||||
try:
|
try:
|
||||||
response, file_handler = await get_function_call_response(
|
response, citation, file_handler = (
|
||||||
|
await get_function_call_response(
|
||||||
messages=data["messages"],
|
messages=data["messages"],
|
||||||
files=data.get("files", []),
|
files=data.get("files", []),
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
@ -419,11 +434,16 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
task_model_id=task_model_id,
|
task_model_id=task_model_id,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
print(file_handler)
|
print(file_handler)
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
|
|
||||||
|
if citation:
|
||||||
|
citations.append(citation)
|
||||||
|
show_citations = True
|
||||||
|
|
||||||
if file_handler:
|
if file_handler:
|
||||||
skip_files = True
|
skip_files = True
|
||||||
|
|
||||||
@ -438,7 +458,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
if "files" in data:
|
if "files" in data:
|
||||||
if not skip_files:
|
if not skip_files:
|
||||||
data = {**data}
|
data = {**data}
|
||||||
rag_context, citations = get_rag_context(
|
rag_context, rag_citations = get_rag_context(
|
||||||
files=data["files"],
|
files=data["files"],
|
||||||
messages=data["messages"],
|
messages=data["messages"],
|
||||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||||
@ -452,13 +472,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||||
|
|
||||||
if citations and data.get("citations"):
|
if rag_citations:
|
||||||
data_items.append({"citations": citations})
|
citations.extend(rag_citations)
|
||||||
|
|
||||||
del data["files"]
|
del data["files"]
|
||||||
|
|
||||||
if data.get("citations"):
|
if show_citations and len(citations) > 0:
|
||||||
del data["citations"]
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
if context != "":
|
if context != "":
|
||||||
system_prompt = rag_template(
|
system_prompt = rag_template(
|
||||||
@ -1285,7 +1305,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|||||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context, file_handler = await get_function_call_response(
|
context, citation, file_handler = await get_function_call_response(
|
||||||
form_data["messages"],
|
form_data["messages"],
|
||||||
form_data.get("files", []),
|
form_data.get("files", []),
|
||||||
form_data["tool_id"],
|
form_data["tool_id"],
|
||||||
|
Loading…
Reference in New Issue
Block a user