mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 13:40:55 +00:00
feat: tools file handler support
This commit is contained in:
parent
d6ab954f81
commit
a2e1ea103c
@ -241,6 +241,12 @@ async def get_function_call_response(
|
||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||
|
||||
file_handler = False
|
||||
# check if toolkit_module has file_handler self variable
|
||||
if hasattr(toolkit_module, "file_handler"):
|
||||
file_handler = True
|
||||
print("file_handler: ", file_handler)
|
||||
|
||||
function = getattr(toolkit_module, result["name"])
|
||||
function_result = None
|
||||
try:
|
||||
@ -279,12 +285,12 @@ async def get_function_call_response(
|
||||
print(e)
|
||||
|
||||
# Add the function result to the system prompt
|
||||
if function_result:
|
||||
return function_result
|
||||
if function_result is not None:
|
||||
return function_result, file_handler
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
return None
|
||||
return None, False
|
||||
|
||||
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
context = ""
|
||||
|
||||
# If tool_ids field is present, call the functions
|
||||
|
||||
skip_files = False
|
||||
if "tool_ids" in data:
|
||||
print(data["tool_ids"])
|
||||
for tool_id in data["tool_ids"]:
|
||||
print(tool_id)
|
||||
try:
|
||||
response = await get_function_call_response(
|
||||
response, file_handler = await get_function_call_response(
|
||||
messages=data["messages"],
|
||||
files=data.get("files", []),
|
||||
tool_id=tool_id,
|
||||
@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
user=user,
|
||||
)
|
||||
|
||||
print(file_handler)
|
||||
if isinstance(response, str):
|
||||
context += ("\n" if context != "" else "") + response
|
||||
|
||||
if file_handler:
|
||||
skip_files = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
del data["tool_ids"]
|
||||
|
||||
print(f"tool_context: {context}")
|
||||
|
||||
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
|
||||
# If files field is present, generate RAG completions
|
||||
# If skip_files is True, skip the RAG completions
|
||||
if "files" in data:
|
||||
data = {**data}
|
||||
rag_context, citations = get_rag_context(
|
||||
files=data["files"],
|
||||
messages=data["messages"],
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.config.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
if not skip_files:
|
||||
data = {**data}
|
||||
rag_context, citations = get_rag_context(
|
||||
files=data["files"],
|
||||
messages=data["messages"],
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.config.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
if rag_context:
|
||||
context += ("\n" if context != "" else "") + rag_context
|
||||
|
||||
if rag_context:
|
||||
context += ("\n" if context != "" else "") + rag_context
|
||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||
else:
|
||||
return_citations = False
|
||||
|
||||
del data["files"]
|
||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||
|
||||
if context != "":
|
||||
system_prompt = rag_template(
|
||||
@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
|
||||
try:
|
||||
context = await get_function_call_response(
|
||||
context, file_handler = await get_function_call_response(
|
||||
form_data["messages"],
|
||||
form_data.get("files", []),
|
||||
form_data["tool_id"],
|
||||
|
Loading…
Reference in New Issue
Block a user