feat: user hook

This commit is contained in:
Timothy J. Baek 2024-06-11 10:19:59 -07:00
parent f62b15d8da
commit 8a86f32700

View File

@ -11,6 +11,7 @@ import requests
import mimetypes
import shutil
import os
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@ -204,6 +205,8 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
# Parse the function response
if content is not None:
print(content)
result = json.loads(content)
print(result)
@ -218,7 +221,24 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
# Get the signature of the function
sig = inspect.signature(function)
# Check if 'user' is a parameter of the function
if "user" in sig.parameters:
# Call the function with the 'user' parameter included
function_result = function(
**{
**result["parameters"],
"user": {
"id": user.id,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
@ -284,6 +304,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
prompt = get_last_user_message(data["messages"])
for tool_id in data["tool_ids"]:
print(tool_id)
@ -299,7 +320,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context += ("\n" if context != "" else "") + response
del data["tool_ids"]
print(context)
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data: