This commit is contained in:
Timothy J. Baek 2024-06-23 19:18:13 -07:00
parent 26e735618e
commit 0cf936f9e8
5 changed files with 36 additions and 8 deletions

View File

@ -143,10 +143,10 @@ class FunctionsTable:
for function in Function.select().where(Function.type == type)
]
def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]:
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try:
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
return function.valves if "valves" in function and function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None

View File

@ -114,10 +114,10 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]:
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
return tool.valves if "valves" in tool and tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None

View File

@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
try:
function_valves = Functions.get_function_valves_by_id(id)
return function_valves.valves
valves = Functions.get_function_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
tool_valves = Tools.get_tool_valves_by_id(id)
return tool_valves.valves
valves = Tools.get_tool_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -262,6 +262,13 @@ async def get_function_call_response(
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
toolkit_module.valves = toolkit_module.Valves(
**Tools.get_tool_valves_by_id(tool_id)
)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
@ -402,6 +409,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(filter_id)
)
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
@ -884,6 +898,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
else:
function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(pipe_id)
)
pipe = function_module.pipe
# Get the signature of the function
@ -1105,6 +1126,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
function_module.valves = function_module.Valves(
**Functions.get_function_valves_by_id(filter_id)
)
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet