fix user valves

This commit is contained in:
Michael Poluektov 2024-08-22 13:34:35 +01:00
parent 14f0e6a2ba
commit 16ec25d296

View File

@ -56,12 +56,15 @@ from apps.socket.main import get_event_call, get_event_emitter
import inspect import inspect
import json import json
import logging
from typing import Iterator, Generator, AsyncGenerator from typing import Iterator, Generator, AsyncGenerator
from pydantic import BaseModel from pydantic import BaseModel
app = FastAPI() app = FastAPI()
log = logging.getLogger(__name__)
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@ -243,33 +246,23 @@ def get_pipe_id(form_data: dict) -> str:
return pipe_id return pipe_id
def get_function_params(function_module, form_data, user, extra_params={}): def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function_module.pipe) sig = inspect.signature(function_module.pipe)
params = {"body": form_data} addition_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
params = {"body": form_data} | addition_params
for key, value in extra_params.items(): if "__user__" in params and hasattr(function_module, "UserValves"):
if key in sig.parameters: user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
params[key] = value if user_valves:
try:
if "__user__" in sig.parameters: params["__user__"]["valves"] = function_module.UserValves(**user_valves)
__user__ = { except Exception as e:
"id": user.id, log.exception(e)
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params["__user__"] = __user__
return params return params
@ -298,6 +291,12 @@ async def generate_function_chat_completion(form_data, user):
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
"__event_call__": __event_call__, "__event_call__": __event_call__,
"__task__": __task__, "__task__": __task__,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
} }
tools_params = { tools_params = {
**extra_params, **extra_params,