feat: tools file handler support

This commit is contained in:
Timothy J. Baek 2024-06-18 16:45:03 -07:00
parent d6ab954f81
commit a2e1ea103c
1 changed files with 34 additions and 19 deletions

View File

@ -241,6 +241,12 @@ async def get_function_call_response(
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
@ -279,12 +285,12 @@ async def get_function_call_response(
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
if function_result is not None:
return function_result, file_handler
except Exception as e:
print(f"Error: {e}")
return None
return None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware):
@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context = ""
# If tool_ids field is present, call the functions
skip_files = False
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response = await get_function_call_response(
response, file_handler = await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user=user,
)
print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data:
data = {**data}
rag_context, citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if not skip_files:
data = {**data}
rag_context, citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
else:
return_citations = False
del data["files"]
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context = await get_function_call_response(
context, file_handler = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],