From e6c64282fc920897c89e98c11f0ada475bed9e61 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 2 Aug 2024 01:45:50 +0200 Subject: [PATCH] refac --- backend/apps/webui/main.py | 53 +++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 96adb5080..972562a04 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -239,10 +239,10 @@ def get_pipe_id(form_data: dict) -> str: return pipe_id -def get_params_dict(pipe, form_data, user, extra_params, function_module): +def get_function_params(function_module, form_data, user, extra_params={}): pipe_id = get_pipe_id(form_data) # Get the signature of the function - sig = inspect.signature(pipe) + sig = inspect.signature(function_module.pipe) params = {"body": form_data} for key, value in extra_params.items(): @@ -269,26 +269,8 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module): return params -def get_extra_params(metadata: dict): - __event_emitter__ = None - __event_call__ = None - __task__ = None - - if metadata: - if all(k in metadata for k in ("session_id", "chat_id", "message_id")): - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) - __task__ = metadata.get("task", None) - - return { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - } - - # inplace function: form_data is modified -def add_model_params(params: dict, form_data: dict) -> dict: +def apply_model_params_to_body(params: dict, form_data: dict) -> dict: if not params: return form_data @@ -309,7 +291,7 @@ def add_model_params(params: dict, form_data: dict) -> dict: # inplace function: form_data is modified -def populate_system_message(params: dict, form_data: dict, user) -> dict: +def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: system = params.get("system", None) if not system: return form_data @@ -333,21 +315,38 @@ async def generate_function_chat_completion(form_data, user): model_info = Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", None) - # Add extra params such as __event_emitter__ - extra_params = get_extra_params(metadata) + __event_emitter__ = None + __event_call__ = None + __task__ = None + + if metadata: + if all(k in metadata for k in ("session_id", "chat_id", "message_id")): + __event_emitter__ = get_event_emitter(metadata) + __event_call__ = get_event_call(metadata) + __task__ = metadata.get("task", None) + if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = add_model_params(params, form_data) - form_data = populate_system_message(params, form_data, user) + form_data = apply_model_params_to_body(params, form_data) + form_data = apply_model_system_prompt_to_body(params, form_data, user) pipe_id = get_pipe_id(form_data) function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_params_dict(pipe, form_data, user, extra_params, function_module) + params = get_function_params( + function_module, + form_data, + user, + { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + }, + ) if form_data["stream"]: