feat: tools file handler support

This commit is contained in:
Timothy J. Baek 2024-06-18 16:45:03 -07:00
parent d6ab954f81
commit a2e1ea103c

View File

@ -241,6 +241,12 @@ async def get_function_call_response(
toolkit_module = load_toolkit_module_by_id(tool_id) toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
@ -279,12 +285,12 @@ async def get_function_call_response(
print(e) print(e)
# Add the function result to the system prompt # Add the function result to the system prompt
if function_result: if function_result is not None:
return function_result return function_result, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return None return None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context = "" context = ""
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
skip_files = False
if "tool_ids" in data: if "tool_ids" in data:
print(data["tool_ids"]) print(data["tool_ids"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
try: try:
response = await get_function_call_response( response, file_handler = await get_function_call_response(
messages=data["messages"], messages=data["messages"],
files=data.get("files", []), files=data.get("files", []),
tool_id=tool_id, tool_id=tool_id,
@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user=user, user=user,
) )
print(file_handler)
if isinstance(response, str): if isinstance(response, str):
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
if file_handler:
skip_files = True
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
del data["tool_ids"] del data["tool_ids"]
print(f"tool_context: {context}") print(f"tool_context: {context}")
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
# If files field is present, generate RAG completions # If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data: if "files" in data:
data = {**data} if not skip_files:
rag_context, citations = get_rag_context( data = {**data}
files=data["files"], rag_context, citations = get_rag_context(
messages=data["messages"], files=data["files"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION, messages=data["messages"],
k=rag_app.state.config.TOP_K, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
reranking_function=rag_app.state.sentence_transformer_rf, k=rag_app.state.config.TOP_K,
r=rag_app.state.config.RELEVANCE_THRESHOLD, reranking_function=rag_app.state.sentence_transformer_rf,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, r=rag_app.state.config.RELEVANCE_THRESHOLD,
) hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
if rag_context: log.debug(f"rag_context: {rag_context}, citations: {citations}")
context += ("\n" if context != "" else "") + rag_context else:
return_citations = False
del data["files"] del data["files"]
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "": if context != "":
system_prompt = rag_template( system_prompt = rag_template(
@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context = await get_function_call_response( context, file_handler = await get_function_call_response(
form_data["messages"], form_data["messages"],
form_data.get("files", []), form_data.get("files", []),
form_data["tool_id"], form_data["tool_id"],