Merge pull request #4815 from michaelpoluektov/fix-user-valves

fix: Fix user valves
This commit is contained in:
Timothy Jaeryang Baek 2024-08-22 15:25:45 +02:00 committed by GitHub
commit 99db82a161
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,12 +56,15 @@ from apps.socket.main import get_event_call, get_event_emitter
import inspect
import json
import logging
from typing import Iterator, Generator, AsyncGenerator
from pydantic import BaseModel
app = FastAPI()
log = logging.getLogger(__name__)
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@ -243,44 +246,37 @@ def get_pipe_id(form_data: dict) -> str:
return pipe_id
def get_function_params(function_module, form_data, user, extra_params={}):
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data}
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
print(e)
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
params["__user__"] = __user__
return params
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
@ -299,18 +295,25 @@ async def generate_function_chat_completion(form_data, user):
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
}
extra_params["__tools__"] = get_tools(
app,
tool_ids,
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
extra_params["__tools__"] = (
get_tools(
app,
tool_ids,
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
},
),
)
if model_info: