mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
feat: tools file handler support
This commit is contained in:
parent
d6ab954f81
commit
a2e1ea103c
@ -241,6 +241,12 @@ async def get_function_call_response(
|
|||||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||||
|
|
||||||
|
file_handler = False
|
||||||
|
# check if toolkit_module has file_handler self variable
|
||||||
|
if hasattr(toolkit_module, "file_handler"):
|
||||||
|
file_handler = True
|
||||||
|
print("file_handler: ", file_handler)
|
||||||
|
|
||||||
function = getattr(toolkit_module, result["name"])
|
function = getattr(toolkit_module, result["name"])
|
||||||
function_result = None
|
function_result = None
|
||||||
try:
|
try:
|
||||||
@ -279,12 +285,12 @@ async def get_function_call_response(
|
|||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
# Add the function result to the system prompt
|
# Add the function result to the system prompt
|
||||||
if function_result:
|
if function_result is not None:
|
||||||
return function_result
|
return function_result, file_handler
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
return None
|
return None, False
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
context = ""
|
context = ""
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
|
|
||||||
|
skip_files = False
|
||||||
if "tool_ids" in data:
|
if "tool_ids" in data:
|
||||||
print(data["tool_ids"])
|
print(data["tool_ids"])
|
||||||
for tool_id in data["tool_ids"]:
|
for tool_id in data["tool_ids"]:
|
||||||
print(tool_id)
|
print(tool_id)
|
||||||
try:
|
try:
|
||||||
response = await get_function_call_response(
|
response, file_handler = await get_function_call_response(
|
||||||
messages=data["messages"],
|
messages=data["messages"],
|
||||||
files=data.get("files", []),
|
files=data.get("files", []),
|
||||||
tool_id=tool_id,
|
tool_id=tool_id,
|
||||||
@ -354,18 +362,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print(file_handler)
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
context += ("\n" if context != "" else "") + response
|
context += ("\n" if context != "" else "") + response
|
||||||
|
|
||||||
|
if file_handler:
|
||||||
|
skip_files = True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
del data["tool_ids"]
|
del data["tool_ids"]
|
||||||
|
|
||||||
print(f"tool_context: {context}")
|
print(f"tool_context: {context}")
|
||||||
|
|
||||||
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
|
|
||||||
# If files field is present, generate RAG completions
|
# If files field is present, generate RAG completions
|
||||||
|
# If skip_files is True, skip the RAG completions
|
||||||
if "files" in data:
|
if "files" in data:
|
||||||
|
if not skip_files:
|
||||||
data = {**data}
|
data = {**data}
|
||||||
rag_context, citations = get_rag_context(
|
rag_context, citations = get_rag_context(
|
||||||
files=data["files"],
|
files=data["files"],
|
||||||
@ -376,12 +389,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
)
|
)
|
||||||
|
|
||||||
if rag_context:
|
if rag_context:
|
||||||
context += ("\n" if context != "" else "") + rag_context
|
context += ("\n" if context != "" else "") + rag_context
|
||||||
|
|
||||||
del data["files"]
|
|
||||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||||
|
else:
|
||||||
|
return_citations = False
|
||||||
|
|
||||||
|
del data["files"]
|
||||||
|
|
||||||
if context != "":
|
if context != "":
|
||||||
system_prompt = rag_template(
|
system_prompt = rag_template(
|
||||||
@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|||||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = await get_function_call_response(
|
context, file_handler = await get_function_call_response(
|
||||||
form_data["messages"],
|
form_data["messages"],
|
||||||
form_data.get("files", []),
|
form_data.get("files", []),
|
||||||
form_data["tool_id"],
|
form_data["tool_id"],
|
||||||
|
Loading…
Reference in New Issue
Block a user