Merge pull request #4815 from michaelpoluektov/fix-user-valves

fix: Fix user valves
This commit is contained in:
Timothy Jaeryang Baek 2024-08-22 15:25:45 +02:00 committed by GitHub
commit 99db82a161
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,12 +56,15 @@ from apps.socket.main import get_event_call, get_event_emitter
import inspect import inspect
import json import json
import logging
from typing import Iterator, Generator, AsyncGenerator from typing import Iterator, Generator, AsyncGenerator
from pydantic import BaseModel from pydantic import BaseModel
app = FastAPI() app = FastAPI()
log = logging.getLogger(__name__)
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@ -243,44 +246,37 @@ def get_pipe_id(form_data: dict) -> str:
return pipe_id return pipe_id
def get_function_params(function_module, form_data, user, extra_params={}): def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function_module.pipe) sig = inspect.signature(function_module.pipe)
params = {"body": form_data} params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
for key, value in extra_params.items(): }
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try: try:
if hasattr(function_module, "UserValves"): params["__user__"]["valves"] = function_module.UserValves(**user_valves)
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e: except Exception as e:
print(e) log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
params["__user__"] = __user__
return params return params
async def generate_function_chat_completion(form_data, user): async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {}) metadata = form_data.pop("metadata", {})
files = metadata.get("files", []) files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", []) tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None # Check if tool_ids is None
if tool_ids is None: if tool_ids is None:
tool_ids = [] tool_ids = []
@ -299,18 +295,25 @@ async def generate_function_chat_completion(form_data, user):
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
"__event_call__": __event_call__, "__event_call__": __event_call__,
"__task__": __task__, "__task__": __task__,
} "__user__": {
"id": user.id,
extra_params["__tools__"] = get_tools( "email": user.email,
app, "name": user.name,
tool_ids, "role": user.role,
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
}, },
}
extra_params["__tools__"] = (
get_tools(
app,
tool_ids,
user,
{
**extra_params,
"__model__": app.state.MODELS[form_data["model"]],
"__messages__": form_data["messages"],
"__files__": files,
},
),
) )
if model_info: if model_info: