feat: tool citation

This commit is contained in:
Timothy J. Baek 2024-06-20 14:14:12 -07:00
parent 58ae91369e
commit 6bb2f41812

View File

@ -247,6 +247,7 @@ async def get_function_call_response(
result = json.loads(content)
print(result)
citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
@ -309,22 +310,32 @@ async def get_function_call_response(
}
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, file_handler
return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
return None, False
return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
data_items = []
show_citations = False
citations = []
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
@ -342,6 +353,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
@ -365,8 +379,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if getattr(function_module, "file_handler"):
skip_files = True
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
try:
if hasattr(function_module, "inlet"):
@ -411,19 +425,25 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response, file_handler = await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
response, citation, file_handler = (
await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
)
print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
if citation:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
@ -438,7 +458,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if "files" in data:
if not skip_files:
data = {**data}
rag_context, citations = get_rag_context(
rag_context, rag_citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
@ -452,13 +472,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if citations and data.get("citations"):
data_items.append({"citations": citations})
if rag_citations:
citations.extend(rag_citations)
del data["files"]
if data.get("citations"):
del data["citations"]
if show_citations and len(citations) > 0:
data_items.append({"citations": citations})
if context != "":
system_prompt = rag_template(
@ -1285,7 +1305,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context, file_handler = await get_function_call_response(
context, citation, file_handler = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],